Full Code of VainF/Diff-Pruning for AI

main 71e922e54275 cached
439 files
8.8 MB
2.3M tokens
3841 symbols
1 requests
Download .txt
Showing preview only (9,287K chars total). Download the full file or copy to clipboard to get everything.
Repository: VainF/Diff-Pruning
Branch: main
Commit: 71e922e54275
Files: 439
Total size: 8.8 MB

Directory structure:
gitextract_bsd98edx/

├── .gitignore
├── LICENSE
├── README.md
├── ddpm_exp/
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── calc_fid.py
│   ├── compute_flops.py
│   ├── compute_pruned_ssim_curve.py
│   ├── compute_ssim.py
│   ├── compute_ssim_vis.py
│   ├── configs/
│   │   ├── bedroom.yml
│   │   ├── celeba.yml
│   │   ├── church.yml
│   │   ├── cifar10.yml
│   │   └── cifar10_pruning.yml
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── celeba.py
│   │   ├── ffhq.py
│   │   ├── lsun.py
│   │   ├── utils.py
│   │   └── vision.py
│   ├── draw_ssim_pruned_curve.py
│   ├── extract_cifar10.py
│   ├── fid_score.py
│   ├── finetune.py
│   ├── finetune_simple.py
│   ├── functions/
│   │   ├── __init__.py
│   │   ├── ckpt_util.py
│   │   ├── denoising.py
│   │   └── losses.py
│   ├── inception.py
│   ├── main.py
│   ├── models/
│   │   ├── diffusion.py
│   │   └── ema.py
│   ├── prune.py
│   ├── prune_kd.py
│   ├── prune_ssim.py
│   ├── prune_test.py
│   ├── runners/
│   │   ├── __init__.py
│   │   ├── diffusion.py
│   │   └── diffusion_simple.py
│   ├── scripts/
│   │   ├── finetune_bedroom_ddpm.sh
│   │   ├── finetune_celeba_ddpm.sh
│   │   ├── finetune_celeba_ddpm_kd.sh
│   │   ├── finetune_church_ddpm.sh
│   │   ├── finetune_cifar_ddpm.sh
│   │   ├── finetune_cifar_ddpm_kd.sh
│   │   ├── finetune_cifar_ddpm_random.sh
│   │   ├── finetune_cifar_ddpm_taylor.sh
│   │   ├── old/
│   │   │   ├── run_bedroom_sample_pratrained.sh
│   │   │   ├── run_celeba_pruning_scratch.sh
│   │   │   ├── run_celeba_pruning_taylor.sh
│   │   │   ├── run_celeba_sample_pratrained.sh
│   │   │   ├── run_church_pruning_taylor.sh
│   │   │   ├── run_cifar_pruning_first_order_taylor.sh
│   │   │   ├── run_cifar_pruning_magnitude.sh
│   │   │   ├── run_cifar_pruning_random.sh
│   │   │   ├── run_cifar_pruning_random_kd.sh
│   │   │   ├── run_cifar_pruning_scratch.sh
│   │   │   ├── run_cifar_pruning_second_order_taylor.sh
│   │   │   ├── run_cifar_pruning_taylor.sh
│   │   │   ├── run_cifar_pruning_taylor_kd.sh
│   │   │   └── run_cifar_train.sh
│   │   ├── prune_bedroom_ddpm.sh
│   │   ├── prune_bedroom_ddpm_test.sh
│   │   ├── prune_celeba_ddpm.sh
│   │   ├── prune_celeba_ddpm_ssim.sh
│   │   ├── prune_church_ddpm.sh
│   │   ├── prune_church_ddpm_test.sh
│   │   ├── prune_cifar_ddpm.sh
│   │   ├── prune_cifar_ddpm_ssim.sh
│   │   ├── prune_cifar_ddpm_test.sh
│   │   ├── run_celeba.sh
│   │   ├── sample_bedroom_ddpm_pretrained.sh
│   │   ├── sample_bedroom_ddpm_pruning.sh
│   │   ├── sample_celeba_ddpm_pruning.sh
│   │   ├── sample_celeba_pretrained.sh
│   │   ├── sample_church_ddpm_pruning.sh
│   │   ├── sample_church_ddpm_pruning_old.sh
│   │   ├── sample_church_ddpm_test.sh
│   │   ├── sample_church_pretrained.sh
│   │   ├── sample_cifar_ddpm_pruning.sh
│   │   ├── sample_cifar_pretrained.sh
│   │   ├── simple_celeba_our.sh
│   │   └── simple_cifar_our.sh
│   ├── tools/
│   │   ├── extract_cifar10.py
│   │   └── transform_weights.py
│   ├── torch_pruning/
│   │   ├── __init__.py
│   │   ├── _helpers.py
│   │   ├── dependency.py
│   │   ├── importance.py
│   │   ├── ops.py
│   │   ├── pruner/
│   │   │   ├── __init__.py
│   │   │   ├── algorithms/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── batchnorm_scale_pruner.py
│   │   │   │   ├── group_norm_pruner.py
│   │   │   │   ├── magnitude_based_pruner.py
│   │   │   │   ├── metapruner.py
│   │   │   │   ├── scaling_factor_pruner.py
│   │   │   │   ├── scheduler.py
│   │   │   │   └── taylor_pruner.py
│   │   │   └── function.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── op_counter.py
│   │       └── utils.py
│   └── utils.py
├── ddpm_prune.py
├── ddpm_sample.py
├── ddpm_train.py
├── diffusers/
│   ├── __init__.py
│   ├── commands/
│   │   ├── __init__.py
│   │   ├── diffusers_cli.py
│   │   └── env.py
│   ├── configuration_utils.py
│   ├── dependency_versions_check.py
│   ├── dependency_versions_table.py
│   ├── experimental/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   └── rl/
│   │       ├── __init__.py
│   │       └── value_guided_sampling.py
│   ├── image_processor.py
│   ├── loaders.py
│   ├── models/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── attention_flax.py
│   │   ├── attention_processor.py
│   │   ├── autoencoder_kl.py
│   │   ├── controlnet.py
│   │   ├── controlnet_flax.py
│   │   ├── cross_attention.py
│   │   ├── dual_transformer_2d.py
│   │   ├── embeddings.py
│   │   ├── embeddings_flax.py
│   │   ├── modeling_flax_pytorch_utils.py
│   │   ├── modeling_flax_utils.py
│   │   ├── modeling_pytorch_flax_utils.py
│   │   ├── modeling_utils.py
│   │   ├── prior_transformer.py
│   │   ├── resnet.py
│   │   ├── resnet_flax.py
│   │   ├── t5_film_transformer.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_temporal.py
│   │   ├── unet_1d.py
│   │   ├── unet_1d_blocks.py
│   │   ├── unet_2d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_blocks_flax.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_2d_condition_flax.py
│   │   ├── unet_3d_blocks.py
│   │   ├── unet_3d_condition.py
│   │   ├── vae.py
│   │   ├── vae_flax.py
│   │   └── vq_model.py
│   ├── optimization.py
│   ├── pipeline_utils.py
│   ├── pipelines/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── alt_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── modeling_roberta_series.py
│   │   │   ├── pipeline_alt_diffusion.py
│   │   │   └── pipeline_alt_diffusion_img2img.py
│   │   ├── audio_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── mel.py
│   │   │   └── pipeline_audio_diffusion.py
│   │   ├── audioldm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_audioldm.py
│   │   ├── controlnet/
│   │   │   ├── __init__.py
│   │   │   ├── multicontrolnet.py
│   │   │   ├── pipeline_controlnet.py
│   │   │   ├── pipeline_controlnet_img2img.py
│   │   │   ├── pipeline_controlnet_inpaint.py
│   │   │   └── pipeline_flax_controlnet.py
│   │   ├── dance_diffusion/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_dance_diffusion.py
│   │   ├── ddim/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_ddim.py
│   │   ├── ddpm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_ddpm.py
│   │   ├── deepfloyd_if/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_if.py
│   │   │   ├── pipeline_if_img2img.py
│   │   │   ├── pipeline_if_img2img_superresolution.py
│   │   │   ├── pipeline_if_inpainting.py
│   │   │   ├── pipeline_if_inpainting_superresolution.py
│   │   │   ├── pipeline_if_superresolution.py
│   │   │   ├── safety_checker.py
│   │   │   ├── timesteps.py
│   │   │   └── watermark.py
│   │   ├── dit/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_dit.py
│   │   ├── latent_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_latent_diffusion.py
│   │   │   └── pipeline_latent_diffusion_superresolution.py
│   │   ├── latent_diffusion_uncond/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_latent_diffusion_uncond.py
│   │   ├── onnx_utils.py
│   │   ├── paint_by_example/
│   │   │   ├── __init__.py
│   │   │   ├── image_encoder.py
│   │   │   └── pipeline_paint_by_example.py
│   │   ├── pipeline_flax_utils.py
│   │   ├── pipeline_utils.py
│   │   ├── pndm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_pndm.py
│   │   ├── repaint/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_repaint.py
│   │   ├── score_sde_ve/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_score_sde_ve.py
│   │   ├── semantic_stable_diffusion/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_semantic_stable_diffusion.py
│   │   ├── spectrogram_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── continous_encoder.py
│   │   │   ├── midi_utils.py
│   │   │   ├── notes_encoder.py
│   │   │   └── pipeline_spectrogram_diffusion.py
│   │   ├── stable_diffusion/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── convert_from_ckpt.py
│   │   │   ├── pipeline_cycle_diffusion.py
│   │   │   ├── pipeline_flax_stable_diffusion.py
│   │   │   ├── pipeline_flax_stable_diffusion_controlnet.py
│   │   │   ├── pipeline_flax_stable_diffusion_img2img.py
│   │   │   ├── pipeline_flax_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_onnx_stable_diffusion.py
│   │   │   ├── pipeline_onnx_stable_diffusion_img2img.py
│   │   │   ├── pipeline_onnx_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py
│   │   │   ├── pipeline_onnx_stable_diffusion_upscale.py
│   │   │   ├── pipeline_stable_diffusion.py
│   │   │   ├── pipeline_stable_diffusion_attend_and_excite.py
│   │   │   ├── pipeline_stable_diffusion_controlnet.py
│   │   │   ├── pipeline_stable_diffusion_depth2img.py
│   │   │   ├── pipeline_stable_diffusion_diffedit.py
│   │   │   ├── pipeline_stable_diffusion_image_variation.py
│   │   │   ├── pipeline_stable_diffusion_img2img.py
│   │   │   ├── pipeline_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_stable_diffusion_inpaint_legacy.py
│   │   │   ├── pipeline_stable_diffusion_instruct_pix2pix.py
│   │   │   ├── pipeline_stable_diffusion_k_diffusion.py
│   │   │   ├── pipeline_stable_diffusion_latent_upscale.py
│   │   │   ├── pipeline_stable_diffusion_model_editing.py
│   │   │   ├── pipeline_stable_diffusion_panorama.py
│   │   │   ├── pipeline_stable_diffusion_pix2pix_zero.py
│   │   │   ├── pipeline_stable_diffusion_sag.py
│   │   │   ├── pipeline_stable_diffusion_upscale.py
│   │   │   ├── pipeline_stable_unclip.py
│   │   │   ├── pipeline_stable_unclip_img2img.py
│   │   │   ├── safety_checker.py
│   │   │   ├── safety_checker_flax.py
│   │   │   └── stable_unclip_image_normalizer.py
│   │   ├── stable_diffusion_safe/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_stable_diffusion_safe.py
│   │   │   └── safety_checker.py
│   │   ├── stochastic_karras_ve/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_stochastic_karras_ve.py
│   │   ├── text_to_video_synthesis/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_text_to_video_synth.py
│   │   │   └── pipeline_text_to_video_zero.py
│   │   ├── unclip/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_unclip.py
│   │   │   ├── pipeline_unclip_image_variation.py
│   │   │   └── text_proj.py
│   │   ├── versatile_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── modeling_text_unet.py
│   │   │   ├── pipeline_versatile_diffusion.py
│   │   │   ├── pipeline_versatile_diffusion_dual_guided.py
│   │   │   ├── pipeline_versatile_diffusion_image_variation.py
│   │   │   └── pipeline_versatile_diffusion_text_to_image.py
│   │   └── vq_diffusion/
│   │       ├── __init__.py
│   │       └── pipeline_vq_diffusion.py
│   ├── schedulers/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── scheduling_ddim.py
│   │   ├── scheduling_ddim_flax.py
│   │   ├── scheduling_ddim_inverse.py
│   │   ├── scheduling_ddpm.py
│   │   ├── scheduling_ddpm_flax.py
│   │   ├── scheduling_deis_multistep.py
│   │   ├── scheduling_dpmsolver_multistep.py
│   │   ├── scheduling_dpmsolver_multistep_flax.py
│   │   ├── scheduling_dpmsolver_multistep_inverse.py
│   │   ├── scheduling_dpmsolver_sde.py
│   │   ├── scheduling_dpmsolver_singlestep.py
│   │   ├── scheduling_euler_ancestral_discrete.py
│   │   ├── scheduling_euler_discrete.py
│   │   ├── scheduling_heun_discrete.py
│   │   ├── scheduling_ipndm.py
│   │   ├── scheduling_k_dpm_2_ancestral_discrete.py
│   │   ├── scheduling_k_dpm_2_discrete.py
│   │   ├── scheduling_karras_ve.py
│   │   ├── scheduling_karras_ve_flax.py
│   │   ├── scheduling_lms_discrete.py
│   │   ├── scheduling_lms_discrete_flax.py
│   │   ├── scheduling_pndm.py
│   │   ├── scheduling_pndm_flax.py
│   │   ├── scheduling_repaint.py
│   │   ├── scheduling_sde_ve.py
│   │   ├── scheduling_sde_ve_flax.py
│   │   ├── scheduling_sde_vp.py
│   │   ├── scheduling_unclip.py
│   │   ├── scheduling_unipc_multistep.py
│   │   ├── scheduling_utils.py
│   │   ├── scheduling_utils_flax.py
│   │   └── scheduling_vq_diffusion.py
│   ├── training_utils.py
│   └── utils/
│       ├── __init__.py
│       ├── accelerate_utils.py
│       ├── constants.py
│       ├── deprecation_utils.py
│       ├── doc_utils.py
│       ├── dummy_flax_and_transformers_objects.py
│       ├── dummy_flax_objects.py
│       ├── dummy_note_seq_objects.py
│       ├── dummy_onnx_objects.py
│       ├── dummy_pt_objects.py
│       ├── dummy_torch_and_librosa_objects.py
│       ├── dummy_torch_and_scipy_objects.py
│       ├── dummy_torch_and_torchsde_objects.py
│       ├── dummy_torch_and_transformers_and_k_diffusion_objects.py
│       ├── dummy_torch_and_transformers_and_onnx_objects.py
│       ├── dummy_torch_and_transformers_objects.py
│       ├── dummy_transformers_and_torch_and_note_seq_objects.py
│       ├── dynamic_modules_utils.py
│       ├── hub_utils.py
│       ├── import_utils.py
│       ├── logging.py
│       ├── model_card_template.md
│       ├── outputs.py
│       ├── pil_utils.py
│       ├── testing_utils.py
│       └── torch_utils.py
├── fid_score.py
├── inception.py
├── ldm_exp/
│   ├── LICENSE
│   ├── README.md
│   ├── configs/
│   │   ├── autoencoder/
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   └── autoencoder_kl_8x8x64.yaml
│   │   ├── latent-diffusion/
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── cin-ldm-vq-f8.yaml
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   └── txt2img-1p4B-eval.yaml
│   │   └── retrieval-augmented-diffusion/
│   │       └── 768x768.yaml
│   ├── environment.yaml
│   ├── fid_score.py
│   ├── inception.py
│   ├── ldm/
│   │   ├── lr_scheduler.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion/
│   │   │       ├── __init__.py
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       └── plms.py
│   │   ├── modules/
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions/
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders/
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   └── utils_image.py
│   │   │   ├── losses/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── main.py
│   ├── models/
│   │   ├── first_stage_models/
│   │   │   ├── kl-f16/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f32/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f4/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f8/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f16/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4-noattn/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f8/
│   │   │   │   └── config.yaml
│   │   │   └── vq-f8-n256/
│   │   │       └── config.yaml
│   │   └── ldm/
│   │       ├── bsr_sr/
│   │       │   └── config.yaml
│   │       ├── celeba256/
│   │       │   └── config.yaml
│   │       ├── cin256/
│   │       │   └── config.yaml
│   │       ├── ffhq256/
│   │       │   └── config.yaml
│   │       ├── inpainting_big/
│   │       │   └── config.yaml
│   │       ├── layout2img-openimages256/
│   │       │   └── config.yaml
│   │       ├── lsun_beds256/
│   │       │   └── config.yaml
│   │       ├── lsun_churches256/
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis256/
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis512/
│   │       │   └── config.yaml
│   │       └── text2img256/
│   │           └── config.yaml
│   ├── notebook_helpers.py
│   ├── profile_ldm.py
│   ├── profile_ldm_pretrained.py
│   ├── profile_model.py
│   ├── prune_ldm.py
│   ├── prune_ldm_no_grad.py
│   ├── run.sh
│   ├── sample_for_FID.py
│   ├── sample_imagenet.py
│   ├── sample_pruned.py
│   ├── scripts/
│   │   ├── download_first_stages.sh
│   │   ├── download_models.sh
│   │   ├── inpaint.py
│   │   ├── knn2img.py
│   │   ├── latent_imagenet_diffusion.ipynb
│   │   ├── sample_diffusion.py
│   │   ├── train_searcher.py
│   │   └── txt2img.py
│   ├── setup.py
│   ├── test_criterion.py
│   └── test_diffusion.py
├── ldm_prune.py
├── requirements.txt
├── scripts/
│   ├── finetune_ddpm_cifar10.sh
│   ├── prune_ddpm_cifar10.sh
│   ├── prune_ddpm_ema_bedroom_random.sh
│   ├── prune_ddpm_ema_church_random.sh
│   ├── prune_ldm.sh
│   ├── sample_ddpm_cifar10_pretrained.sh
│   ├── sample_ddpm_cifar10_pretrained_distributed.sh
│   └── sample_ddpm_cifar10_pruned.sh
├── tools/
│   ├── convert_cifar10_ddpm_ema.sh
│   ├── convert_ddpm_original_checkpoint_to_diffusers_cifar10.py
│   ├── convert_ldm_original_checkpoint_to_diffusers.py
│   ├── ddpm_cifar10_config.json
│   ├── extract_cifar10.py
│   └── ldm_unet_config.json
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Initially taken from Github's Python gitignore file

run2
run
pretrained

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
checkpoints
*.ckpt
# C extensions
*.so

cache

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/
docker
doc
_typos.toml
data
docs

ldm_generated_image.png
*.log

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# examples
runs
/runs_old
/wandb
/examples/runs
/examples/**/*.args
/examples/rag/sweep

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

#ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

# DS_Store (MacOS)
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4

# dependencies
/transformers

# ruff
.ruff_cache

wandb

__pycache__

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2023] [Gongfan Fang]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Diff-Pruning: Structural Pruning for Diffusion Models

<div align="center">
<img src="assets/framework.png" width="80%"></img>
</div>

## Update
Check our latest work [DeepCache](https://horseee.github.io/Diffusion_DeepCache/), a **training-free and almost lossless** method for diffusion model acceleration. It can be viewed as a special pruning technique that dynamically drops deep layers and only runs shallow ones during inference.


## Introduction
> **Structural Pruning for Diffusion Models** [[arxiv]](https://arxiv.org/abs/2305.10924)  
> *[Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)*    
> *National University of Singapore*

This work presents *Diff-Pruning*, an efficient structural pruning method for diffusion models. Our empirical assessment highlights two primary features:
1) ``Efficiency``: It enables approximately a 50% reduction in FLOPs at a mere 10% to 20% of the original training expenditure; 
2) ``Consistency``: The pruned diffusion models inherently preserve generative behavior congruent with the pre-trained ones.

<div align="center">
<img src="assets/LSUN.png" width="80%"></img>
</div>

### Supported Methods
- [x] Magnitude Pruning
- [x] Random Pruning
- [x] Taylor Pruning
- [x] Diff-Pruning (A taylor-based method proposed in our paper)   

### TODO List
- [ ] Support more diffusion models from Diffusers
- [ ] Upload checkpoints of pruned models
- [ ] Training scripts for CelebA-HQ, LSUN Church & LSUN Bedroom
- [ ] Align the performance with the [DDIM Repo](https://github.com/ermongroup/ddim). 

## Our Exp Code (Unorganized)

### Pruning with DDIM codebase
This example shows how to prune a DDPM model pre-trained on CIFAR-10 using the [DDIM codebase](https://github.com/ermongroup/ddim). Since [Huggingface Diffusers](https://github.com/huggingface/diffusers) does not support [``skip_type='quad'``](https://github.com/ermongroup/ddim/issues/3) in DDIM, you may get slightly worse FID scores with Diffusers for both pre-trained models (FID=4.5) and pruned models (FID=5.6). We are working on this to implement the quad strategy for Diffusers. For reproducibility, we provide our original **but unorganized** exp code for the paper in [ddpm_exp](ddpm_exp).

```bash
cd ddpm_exp
# Prune & Finetune
bash scripts/simple_cifar_our.sh 0.05 # the pre-trained model and data will be automatically prepared
# Sampling
bash scripts/sample_cifar_ddpm_pruning.sh run/finetune_simple_v2/cifar10_ours_T=0.05.pth/logs/post_training/ckpt_100000.pth run/sample
```

For FID, please refer to [this section](https://github.com/VainF/Diff-Pruning#4-fid-score).  

Output:
```
Found 49984 files.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:49<00:00,  7.97it/s]
FID:  5.242662673752534
```

### Pruning with LDM codebase

Please check [ldm_exp/run.sh](ldm_exp/run.sh) for an example of pruning a pre-trained LDM model on ImageNet. This codebase is still unorganized. We will clean it up in the future.

## Pruning with Huggingface Diffusers

The following pipeline prunes a pre-trained DDPM on CIFAR-10 with [Huggingface Diffusers](https://github.com/huggingface/diffusers).

### 0. Requirements, Data and Pretrained Model

* Requirements
```bash
pip install -r requirements.txt
```
 
* Data
  
Download and extract CIFAR-10 images to *data/cifar10_images* for training and evaluation.
```bash
python tools/extract_cifar10.py --output data
```
* Pretrained Models
  
The following script will download an official DDPM model and convert it to the format of Huggingface Diffusers. You can find the converted model at *pretrained/ddpm_ema_cifar10*. It is an EMA version of [google/ddpm-cifar10-32](https://huggingface.co/google/ddpm-cifar10-32)
```bash
bash tools/convert_cifar10_ddpm_ema.sh
```

(Optional) You can also download a pre-converted model using wget
```bash
wget https://github.com/VainF/Diff-Pruning/releases/download/v0.0.1/ddpm_ema_cifar10.zip
```

### 1. Pruning
Create a pruned model at *run/pruned/ddpm_cifar10_pruned*
```bash
bash scripts/prune_ddpm_cifar10.sh 0.3  # pruning ratio = 30\%
```

### 2. Finetuning (Post-Training)
Finetune the model and save it at *run/finetuned/ddpm_cifar10_pruned_post_training*
```bash
bash scripts/finetune_ddpm_cifar10.sh
```

### 3. Sampling
**Pruned:** Sample and save images to *run/sample/ddpm_cifar10_pruned*
```bash
bash scripts/sample_ddpm_cifar10_pruned.sh
```

**Pretrained:** Sample and save images to *run/sample/ddpm_cifar10_pretrained*
```bash
bash scripts/sample_ddpm_cifar10_pretrained.sh
```

### 4. FID Score
This script was modified from https://github.com/mseitzer/pytorch-fid. 

```bash
# pre-compute the stats of CIFAR-10 dataset
python fid_score.py --save-stats data/cifar10_images run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
```

```bash
# Compute the FID score of sampled images
python fid_score.py run/sample/ddpm_cifar10_pruned run/fid_stats_cifar10.npz --device cuda:0 --batch-size 256
```

### 5. (Optional) Distributed Training and Sampling with Accelerate
This project supports distributed training and sampling. 
```bash
python -m torch.distributed.launch --nproc_per_node=8 --master_port 22222 --use_env <ddpm_sample.py|ddpm_train.py> ...
```
A multi-processing example can be found at [scripts/sample_ddpm_cifar10_pretrained_distributed.sh](scripts/sample_ddpm_cifar10_pretrained_distributed.sh).


## Prune Pre-trained DPMs from [HuggingFace Diffusers](https://huggingface.co/models?library=diffusers)

### :rocket: [Denoising Diffusion Probabilistic Models (DDPMs)](https://arxiv.org/abs/2006.11239)
Example: [google/ddpm-ema-bedroom-256](https://huggingface.co/google/ddpm-ema-bedroom-256)
```bash
python ddpm_prune.py \
--dataset "<path/to/imagefoler>" \  
--model_path google/ddpm-ema-bedroom-256 \
--save_path run/pruned/ddpm_ema_bedroom_256_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit|taylor|diff-pruning>" \
--batch_size 4 \
--thr 0.05 \
--device cuda:0 \
```
The ``dataset`` and ``thr`` arguments only work for taylor & diff-pruning.


### :rocket: [Latent Diffusion Models (LDMs)](https://arxiv.org/abs/2112.10752)
Example: [CompVis/ldm-celebahq-256](https://huggingface.co/CompVis/ldm-celebahq-256)
```bash
python ldm_prune.py \
--model_path CompVis/ldm-celebahq-256 \
--save_path run/pruned/ldm_celeba_pruned \
--pruning_ratio 0.05 \
--pruner "<random|magnitude|reinit>" \
--device cuda:0 \
--batch_size 4 \
```

## Results

* **DDPM on Cifar-10, CelebA and LSUN**

<div align="center">
<img src="assets/exp.png" width="75%"></img>
<img src="https://github.com/VainF/Diff-Pruning/assets/18592211/39b3a7ad-2abb-4934-9ee0-07724029660b" width="75%"></img>
</div>

* **Conditional LDM on ImageNet-1K 256**

We also have some results on Conditional LDM for ImageNet-1K 256x256, where we finetune a pruned LDM for only 4 epochs. Will release the training script soon.

<div align="center">
<img src="https://github.com/VainF/Diff-Pruning/assets/18592211/31dbf489-2ca2-4625-ba54-5a5ff4e4a626" width="75%"></img>
<img src="https://github.com/VainF/Diff-Pruning/assets/18592211/20d546c5-9012-4ba9-80b2-96ed29da7d07" width="85%"></img>
</div>


## Acknowledgement

This project is heavily based on [Diffusers](https://github.com/huggingface/diffusers), [Torch-Pruning](https://github.com/VainF/Torch-Pruning), [pytorch-fid](https://github.com/mseitzer/pytorch-fid). Our experiments were conducted on [ddim](https://github.com/ermongroup/ddim) and [LDM](https://github.com/CompVis/latent-diffusion).

## Citation
If you find this work helpful, please cite:
```
@inproceedings{fang2023structural,
  title={Structural pruning for diffusion models},
  author={Gongfan Fang and Xinyin Ma and Xinchao Wang},
  booktitle={Advances in Neural Information Processing Systems},
  year={2023},
}
```

```
@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}
```


================================================
FILE: ddpm_exp/.gitignore
================================================
.vscode
__pycache__
*.log
run
data
*.png

================================================
FILE: ddpm_exp/LICENSE
================================================
MIT License

Copyright (c) 2020 Jiaming Song

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: ddpm_exp/README.md
================================================
# Denoising Diffusion Implicit Models (DDIM)

[Jiaming Song](http://tsong.me), [Chenlin Meng](http://cs.stanford.edu/~chenlin) and [Stefano Ermon](http://cs.stanford.edu/~ermon), Stanford

Implements sampling from an implicit model that is trained with the same procedure as [Denoising Diffusion Probabilistic Model](https://hojonathanho.github.io/diffusion/), but costs much less time and compute if you want to sample from it (click image below for a video demo):

<a href="http://www.youtube.com/watch?v=WCKzxoSduJQ" target="_blank">![](http://img.youtube.com/vi/WCKzxoSduJQ/0.jpg)</a>

## **Integration with 🤗 Diffusers library**

DDIM is now also available in 🧨 Diffusers and accessible via the [DDIMPipeline](https://huggingface.co/docs/diffusers/api/pipelines/ddim).
Diffusers allows you to test DDIM in PyTorch in just a couple lines of code.

You can install diffusers as follows:

```
pip install diffusers torch accelerate
```

And then try out the model with just a couple lines of code:

```python
from diffusers import DDIMPipeline

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
image = ddim(num_inference_steps=50).images[0]

# save image
image.save("ddim_generated_image.png")
```

More DDPM/DDIM models compatible with the DDIM pipeline can be found directly [on the Hub](https://huggingface.co/models?library=diffusers&sort=downloads&search=ddpm)

To better understand the DDIM scheduler, you can check out [this introductory google colab](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)

The DDIM scheduler can also be used with more powerful diffusion models such as [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#stable-diffusion-pipelines)

You simply need to [accept the license on the Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5), login with `huggingface-cli login` and install transformers:

```
pip install transformers
```

Then you can run:

```python
from diffusers import StableDiffusionPipeline, DDIMScheduler

ddim = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=ddim)

image = pipeline("An astronaut riding a horse.").images[0]

image.save("astronaut_riding_a_horse.png")
```

## Running the Experiments
The code has been tested on PyTorch 1.6.

### Train a model
Training is exactly the same as DDPM with the following:
```
python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --ni
```

### Sampling from the model

#### Sampling from the generalized model for FID evaluation
```
python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --sample --fid --timesteps {STEPS} --eta {ETA} --ni
```
where 
- `ETA` controls the scale of the variance (0 is DDIM, and 1 is one type of DDPM).
- `STEPS` controls how many timesteps used in the process.
- `MODEL_NAME` finds the pre-trained checkpoint according to its inferred path.

If you want to use the DDPM pretrained model:
```
python main.py --config {DATASET}.yml --exp {PROJECT_PATH} --use_pretrained --sample --fid --timesteps {STEPS} --eta {ETA} --ni
```
the `--use_pretrained` option will automatically load the model according to the dataset.

We provide a CelebA 64x64 model [here](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing), and use the DDPM version for CIFAR10 and LSUN.

If you want to use the version with the larger variance in DDPM: use the `--sample_type ddpm_noisy` option.

#### Sampling from the model for image inpainting 
Use `--interpolation` option instead of `--fid`.

#### Sampling from the sequence of images that lead to the sample
Use `--sequence` option instead.

The above two cases contain some hard-coded lines specific to producing the image, so modify them according to your needs.


## References and Acknowledgements
```
@article{song2020denoising,
  title={Denoising Diffusion Implicit Models},
  author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
  journal={arXiv:2010.02502},
  year={2020},
  month={October},
  abbr={Preprint},
  url={https://arxiv.org/abs/2010.02502}
}
```


This implementation is based on / inspired by:

- [https://github.com/hojonathanho/diffusion](https://github.com/hojonathanho/diffusion) (the DDPM TensorFlow repo), 
- [https://github.com/pesser/pytorch_diffusion](https://github.com/pesser/pytorch_diffusion) (PyTorch helper that loads the DDPM model), and
- [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) (code structure).


================================================
FILE: ddpm_exp/calc_fid.py
================================================
"""Compute an FID score with clean-fid.

Compares the image folder ``--path1`` against either a second image
folder ``--path2``, or — when ``--path2`` is the literal string
``cifar10`` — clean-fid's pre-computed CIFAR-10 training statistics.
"""
from cleanfid import fid
import argparse

parser = argparse.ArgumentParser(description=globals()["__doc__"])
parser.add_argument('--path1', type=str, required=True, help='Path to the images')
parser.add_argument('--path2', type=str, required=True, help='Path to the images')
args = parser.parse_args()

if args.path2=="cifar10":
    # BUG FIX: the original referenced the non-existent ``args.dir`` here,
    # which raised AttributeError; the generated-image folder is args.path1.
    score = fid.compute_fid(args.path1, dataset_name="cifar10", dataset_res=32, dataset_split="train")
else:
    score = fid.compute_fid(args.path1, args.path2)
print("FID: ", score)

================================================
FILE: ddpm_exp/compute_flops.py
================================================
"""Report MACs and parameter count of a saved (possibly pruned) DDPM checkpoint."""
import torch
import random, os
import argparse
from PIL import Image
import torchvision
import numpy as np
import pytorch_msssim
from utils import UnlabeledImageFolder
from tqdm import tqdm 
import torch_pruning as tp

cli = argparse.ArgumentParser()
cli.add_argument('--restore_from', type=str, required=True)
args = cli.parse_args()

# The checkpoint is a tuple whose first element is the model itself.
net = torch.load(args.restore_from, map_location='cpu')[0]
# Dummy CIFAR-10-sized input: one 3x32x32 image plus a single timestep.
dummy_inputs = {'x': torch.randn(1, 3, 32, 32), 't': torch.ones(1)}
macs, params = tp.utils.count_ops_and_params(net, dummy_inputs)
print("model: {}, macs: {} G, params: {} M".format(args.restore_from, macs/1e9, params/1e6))



================================================
FILE: ddpm_exp/compute_pruned_ssim_curve.py
================================================
"""Track how SSIM against a reference folder evolves across pruning steps.

Each folder ``run/prune_ssim_2/<k>`` (k = 50, 100, ..., 1000) holds the
same 32 sample images as the reference folder ``run/prune_ssim_2/0``;
the script prints the mean SSIM of each folder vs. the reference.
"""
import pytorch_msssim 
import os
import torch
from PIL import Image
import torchvision

base_folder_name = 'run/prune_ssim_2/0'
folder_name = [os.path.join('run/prune_ssim_2', '{}'.format(k)) for k in range(50, 1000+1, 50)]
n_samples = 32

_to_tensor = torchvision.transforms.ToTensor()


def _load_image(path):
    # Load a PNG as a 1xCxHxW float tensor in [0, 1].
    return _to_tensor(Image.open(path)).unsqueeze(0)


folder_ssim = []
for folder in folder_name:
    per_image_scores = [
        pytorch_msssim.ssim(
            _load_image(os.path.join(base_folder_name, f'{idx}.png')),
            _load_image(os.path.join(folder, f'{idx}.png')),
            data_range=1.0,
            size_average=True,
        )
        for idx in range(n_samples)
    ]
    mean_score = sum(per_image_scores) / len(per_image_scores)
    folder_ssim.append(mean_score.item())
print(folder_ssim)

================================================
FILE: ddpm_exp/compute_ssim.py
================================================
"""Compute the average SSIM and per-pixel MSE between two folders of images.

Usage: ``python compute_ssim.py --path <folder1> <folder2>``.
Images are paired by their iteration order in each folder (as produced by
``UnlabeledImageFolder``), loaded in batches of 100, and compared on the GPU.
"""
import torch
import random, os
import argparse
from PIL import Image
import torchvision
import numpy as np
import pytorch_msssim
from utils import UnlabeledImageFolder
from tqdm import tqdm 

parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, required=True, nargs='+')
args = parser.parse_args()

# NOTE: removed dead code that drew a random index subset
# (``nrow``/``img_index = random.sample(...)``) — neither value was ever
# used below.
path1 = args.path[0]
path2 = args.path[1]
print(path1, path2)
img_dst1 = UnlabeledImageFolder(path1, transform=torchvision.transforms.ToTensor(), exts=["png"])
img_dst2 = UnlabeledImageFolder(path2, transform=torchvision.transforms.ToTensor(), exts=["png"])
print(len(img_dst1), len(img_dst2))

loader1 = torch.utils.data.DataLoader(
    img_dst1,
    batch_size=100,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)
loader2 = torch.utils.data.DataLoader(
    img_dst2,
    batch_size=100,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

with torch.no_grad():
    ssim_list = []
    mse_list = []
    for i, (img1, img2) in tqdm(enumerate(zip(loader1, loader2))):
        # size_average=False keeps the per-sample SSIM scores so the final
        # average is taken over all images at the end.
        ssim = pytorch_msssim.ssim(img1.cuda(), img2.cuda(), data_range=1.0, size_average=False)
        ssim_list.append(ssim.cpu())
        # Per-image MSE: reduce over channel/height/width only.
        mse = torch.nn.functional.mse_loss(img1.cuda(), img2.cuda(), reduction='none').mean(dim=(1,2,3))
        mse_list.append(mse.cpu())

    ssim = torch.cat(ssim_list, dim=0)
    mse = torch.cat(mse_list, dim=0)
    ssim_avg = ssim.mean()
    mse_avg = mse.mean()
    print("path1: {}, path2: {}, ssim: {}, mse: {}".format(path1, path2, ssim_avg, mse_avg))

    

================================================
FILE: ddpm_exp/compute_ssim_vis.py
================================================
"""Compute per-image SSIM for a hand-picked set of sample ids.

Compares matching ``<id>.png`` files from two sample folders and prints
the list of SSIM values, one per image id.
"""
import torch
import random, os
import argparse
from PIL import Image
import torchvision
import numpy as np
import pytorch_msssim
from utils import UnlabeledImageFolder
from tqdm import tqdm 

# Hard-coded sample ids and folders selected for visualization.
img_ids = [159, 149, 144, 127, 86, 41]
image_folder1 = 'run/sample_v2/bedroom_250k/image_samples/images/0'
image_folder2 = 'run/sample_v2/bedroom_official/image_samples/images/0'
# NOTE: removed the unused local ``base_img_id = 0`` from the original.
ssim_list = []
for iid in img_ids:
    img1 = Image.open(os.path.join(image_folder1, f'{iid}.png'))
    img2 = Image.open(os.path.join(image_folder2, f'{iid}.png'))
    img1_tensor = torchvision.transforms.ToTensor()(img1).unsqueeze(0)
    img2_tensor = torchvision.transforms.ToTensor()(img2).unsqueeze(0)
    ssim = pytorch_msssim.ssim(img1_tensor, img2_tensor, data_range=1.0, size_average=True)
    ssim_list.append(ssim.item())
print(ssim_list)
    

================================================
FILE: ddpm_exp/configs/bedroom.yml
================================================
data:
    dataset: "LSUN"
    category: "bedroom"
    image_size: 256
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 32

model:
    type: "simple"
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 1, 2, 2, 4, 4]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.0
    var_type: fixedsmall
    ema_rate: 0.999
    ema: True
    resamp_with_conv: True

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 8
    n_epochs: 10000
    n_iters: 5000000
    snapshot_freq: 5000
    validation_freq: 2000

sampling:
    batch_size: 16
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.000002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001


================================================
FILE: ddpm_exp/configs/celeba.yml
================================================
data:
    dataset: "CELEBA"
    image_size: 64
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4

model:
    type: "simple"
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2, 4]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 96 # 128
    n_epochs: 10000
    n_iters: 5000000
    snapshot_freq: 5000
    validation_freq: 20000

sampling:
    batch_size: 32
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.0002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001
    grad_clip: 1.0


================================================
FILE: ddpm_exp/configs/church.yml
================================================
data:
    dataset: "LSUN"
    category: "church_outdoor"
    image_size: 256
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 32

model:
    type: "simple"
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 1, 2, 2, 4, 4]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.0
    var_type: fixedsmall
    ema_rate: 0.999
    ema: True
    resamp_with_conv: True

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 8 # 64
    n_epochs: 10000
    n_iters: 5000000
    snapshot_freq: 5000
    validation_freq: 2000

sampling:
    batch_size: 16
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.00002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001


================================================
FILE: ddpm_exp/configs/cifar10.yml
================================================
data:
    dataset: "CIFAR10"
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4

model:
    type: "simple"
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 128
    n_epochs: 10000
    n_iters: 5000000
    snapshot_freq: 5000
    validation_freq: 2000

sampling:
    batch_size: 64
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.0002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001
    grad_clip: 1.0


================================================
FILE: ddpm_exp/configs/cifar10_pruning.yml
================================================
data:
    dataset: "CIFAR10"
    image_size: 32
    channels: 3
    logit_transform: false
    uniform_dequantization: false
    gaussian_dequantization: false
    random_flip: true
    rescaled: true
    num_workers: 4

model:
    type: "simple"
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 2, 2, 2]
    num_res_blocks: 2
    attn_resolutions: [16, ]
    dropout: 0.1
    var_type: fixedlarge
    ema_rate: 0.9999
    ema: True
    resamp_with_conv: True

diffusion:
    beta_schedule: linear
    beta_start: 0.0001
    beta_end: 0.02
    num_diffusion_timesteps: 1000

training:
    batch_size: 128
    n_epochs: 10000
    n_iters: 5000000
    snapshot_freq: 5000
    validation_freq: 2000

sampling:
    batch_size: 64
    last_only: True

optim:
    weight_decay: 0.000
    optimizer: "Adam"
    lr: 0.00002
    beta1: 0.9
    amsgrad: false
    eps: 0.00000001
    grad_clip: 1.0


================================================
FILE: ddpm_exp/datasets/__init__.py
================================================
import os
import torch
import numbers
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.datasets import CIFAR10
from datasets.celeba import CelebA
from datasets.ffhq import FFHQ
from datasets.lsun import LSUN
from torch.utils.data import Subset
import numpy as np


class Crop(object):
    """Fixed-window crop transform.

    Stores an absolute pixel window (x1, x2, y1, y2) at construction time
    and crops every image passed to ``__call__`` to that window using
    torchvision's functional crop.
    """

    def __init__(self, x1, x2, y1, y2):
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2

    def __call__(self, img):
        # F.crop takes (top, left, height, width).
        height = self.x2 - self.x1
        width = self.y2 - self.y1
        return F.crop(img, self.x1, self.y1, height, width)

    def __repr__(self):
        return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format(
            self.x1, self.x2, self.y1, self.y2
        )


def get_dataset(args, config):
    """Construct the (train, test) datasets selected by ``config.data``.

    Args:
        args: parsed command-line namespace. Not referenced in this
            function; kept for call-site compatibility.
        config: experiment config. ``config.data.dataset`` selects one of
            "CIFAR10", "CELEBA", "LSUN", or "FFHQ"; ``image_size`` and
            ``random_flip`` control the transforms.

    Returns:
        Tuple ``(dataset, test_dataset)``; ``(None, None)`` for any
        unrecognized dataset name.
    """
    # Base transforms: resize + ToTensor. Training additionally applies a
    # random horizontal flip when config.data.random_flip is enabled; the
    # test transform never flips.
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size), transforms.ToTensor()]
        )
    else:
        tran_transform = transforms.Compose(
            [
                transforms.Resize(config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
            ]
        )
        test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size), transforms.ToTensor()]
        )

    if config.data.dataset == "CIFAR10":
        dataset = CIFAR10(
            os.path.join('data', "cifar10"),
            train=True,
            download=True,
            transform=tran_transform,
        )
        test_dataset = CIFAR10(
            os.path.join('data', "cifar10"),
            train=False,
            download=True,
            transform=test_transform,
        )

    elif config.data.dataset == "CELEBA":
        # Fixed 128x128 crop window centered at (cx, cy) in the raw CelebA
        # images, applied before resizing to image_size.
        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64
        if config.data.random_flip:
            dataset = CelebA(
                root=os.path.join("data", "celeba"),
                split="train",
                transform=transforms.Compose(
                    [
                        Crop(x1, x2, y1, y2),
                        transforms.Resize(config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]
                ),
                download=False,
            )
        else:
            dataset = CelebA(
                root=os.path.join("data", "celeba"),
                split="train",
                transform=transforms.Compose(
                    [
                        Crop(x1, x2, y1, y2),
                        transforms.Resize(config.data.image_size),
                        transforms.ToTensor(),
                    ]
                ),
                download=False,
            )

        # NOTE(review): the train split uses download=False while the test
        # split uses download=True — confirm this asymmetry is intended.
        test_dataset = CelebA(
            root=os.path.join("data", "celeba"),
            split="test",
            transform=transforms.Compose(
                [
                    Crop(x1, x2, y1, y2),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor(),
                ]
            ),
            download=True,
        )

    elif config.data.dataset == "LSUN":
        # LSUN class folders follow the "<category>_train"/"<category>_val"
        # naming convention.
        train_folder = "{}_train".format(config.data.category)
        val_folder = "{}_val".format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(
                root=os.path.join("data", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose(
                    [
                        transforms.Resize(config.data.image_size),
                        transforms.CenterCrop(config.data.image_size),
                        transforms.RandomHorizontalFlip(p=0.5),
                        transforms.ToTensor(),
                    ]
                ),
            )
        else:
            dataset = LSUN(
                root=os.path.join("data", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose(
                    [
                        transforms.Resize(config.data.image_size),
                        transforms.CenterCrop(config.data.image_size),
                        transforms.ToTensor(),
                    ]
                ),
            )

        test_dataset = LSUN(
            root=os.path.join("data", "lsun"),
            classes=[val_folder],
            transform=transforms.Compose(
                [
                    transforms.Resize(config.data.image_size),
                    transforms.CenterCrop(config.data.image_size),
                    transforms.ToTensor(),
                ]
            ),
        )

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(
                path=os.path.join("data", "FFHQ"),
                transform=transforms.Compose(
                    [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()]
                ),
                resolution=config.data.image_size,
            )
        else:
            dataset = FFHQ(
                path=os.path.join("data", "FFHQ"),
                transform=transforms.ToTensor(),
                resolution=config.data.image_size,
            )

        # Deterministic 90/10 train/test split: shuffle the indices with a
        # fixed seed, then restore the caller's global NumPy RNG state so
        # this split does not perturb other random draws.
        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[: int(num_items * 0.9)],
            indices[int(num_items * 0.9) :],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    else:
        dataset, test_dataset = None, None

    return dataset, test_dataset


def logit_transform(image, lam=1e-6):
    """Map pixel values in [0, 1] to logit space.

    The input is first squeezed into ``[lam, 1 - lam]`` so the logit stays
    finite at the endpoints, then ``log(x / (1 - x))`` is applied
    elementwise (written as ``log(x) - log1p(-x)`` for numerical accuracy).
    """
    squeezed = lam + (1 - 2 * lam) * image
    return torch.log(squeezed) - torch.log1p(-squeezed)


def data_transform(config, X):
    """Apply the configured forward preprocessing to a batch of images X.

    Optional steps, each gated by a ``config.data`` flag: uniform
    dequantization, Gaussian dequantization, then either rescaling to
    [-1, 1] or a logit transform. If ``config.image_mean`` exists, it is
    subtracted at the end (broadcast over the batch dimension).
    """
    data_cfg = config.data
    if data_cfg.uniform_dequantization:
        X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
    if data_cfg.gaussian_dequantization:
        X = X + torch.randn_like(X) * 0.01

    # Rescaling takes precedence over the logit transform.
    if data_cfg.rescaled:
        X = 2 * X - 1.0
    elif data_cfg.logit_transform:
        X = logit_transform(X)

    if not hasattr(config, "image_mean"):
        return X
    return X - config.image_mean.to(X.device)[None, ...]


def inverse_data_transform(config, X):
    """Map network-space values X back to clamped [0, 1] image space.

    Undoes ``data_transform``: re-adds ``config.image_mean`` when present,
    inverts the logit transform (sigmoid) or the [-1, 1] rescaling, and
    finally clamps to [0, 1].
    """
    # Re-add the dataset mean first, mirroring the subtraction order in
    # data_transform.
    if hasattr(config, "image_mean"):
        mean = config.image_mean.to(X.device)[None, ...]
        X = X + mean

    data_cfg = config.data
    if data_cfg.logit_transform:
        X = torch.sigmoid(X)
    elif data_cfg.rescaled:
        X = (X + 1.0) / 2.0

    return torch.clamp(X, min=0.0, max=1.0)


================================================
FILE: ddpm_exp/datasets/celeba.py
================================================
import torch
import os
import PIL
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity


class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                ``identity`` (int): label for each person (data points with the same identity are the same person)
                ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
            Defaults to ``attr``.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. NOTE: downloading and integrity checking are currently
            disabled in ``__init__``; the flag is kept for interface compatibility.
    """

    base_folder = "Img"
    # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                         MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(self, root,
                 split="train",
                 target_type="attr",
                 transform=None, target_transform=None,
                 download=False):
        # pandas is only needed to parse the annotation files; import it lazily
        # so users of the other datasets do not require it.
        import pandas
        super(CelebA, self).__init__(root)
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]
        # Fix: the original assigned transform/target_transform twice
        # (once here and once after the commented-out download block).
        self.transform = transform
        self.target_transform = target_transform

        # Downloading and integrity verification are intentionally disabled;
        # the dataset is expected to be prepared manually under `root`.
        #if download:
        #    self.download()
        #if not self._check_integrity():
        #    raise RuntimeError('Dataset not found or corrupted.' +
        #                       ' You can use download=True to download it')

        # Map the split name onto the integer codes used in list_eval_partition.txt.
        if split.lower() == "train":
            split = 0
        elif split.lower() == "valid":
            split = 1
        elif split.lower() == "test":
            split = 2
        else:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="valid" or split="test"')

        # All annotation files are whitespace-delimited and indexed by filename.
        with open(os.path.join(self.root, 'Eval', "list_eval_partition.txt"), "r") as f:
            splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)

        with open(os.path.join(self.root, 'Anno', "identity_CelebA.txt"), "r") as f:
            self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)

        with open(os.path.join(self.root, 'Anno', "list_bbox_celeba.txt"), "r") as f:
            self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)

        with open(os.path.join(self.root, 'Anno', "list_landmarks_align_celeba.txt"), "r") as f:
            self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)

        with open(os.path.join(self.root, 'Anno', "list_attr_celeba.txt"), "r") as f:
            self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)

        # Keep only the rows belonging to the requested partition.
        mask = (splits[1] == split)
        self.filename = splits[mask].index.values
        self.identity = torch.as_tensor(self.identity[mask].values)
        self.bbox = torch.as_tensor(self.bbox[mask].values)
        self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
        self.attr = torch.as_tensor(self.attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}

    def _check_integrity(self):
        """Verify the MD5 of the annotation files and that images were extracted."""
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self):
        """Download every file in ``file_list`` and extract the image archive."""
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))

    def __getitem__(self, index):
        """Return (image, target) where target follows ``self.target_type``."""
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                raise ValueError("Target type \"{}\" is not recognized.".format(t))
        # Single target types are returned bare, not as a 1-tuple.
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            X = self.transform(X)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return X, target

    def __len__(self):
        return len(self.attr)

    def extra_repr(self):
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)


================================================
FILE: ddpm_exp/datasets/ffhq.py
================================================
from io import BytesIO

import lmdb
from PIL import Image
from torch.utils.data import Dataset


class FFHQ(Dataset):
    """FFHQ images stored in an LMDB database.

    Records are keyed by the pattern ``<resolution>-<index zero-padded to 5>``;
    the database also stores a ``length`` entry with the total sample count.
    Every sample is returned with a constant target of 0.
    """

    def __init__(self, path, transform, resolution=8):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        # The dataset size is stored in the database itself under key 'length'.
        with self.env.begin(write=False) as transaction:
            self.length = int(transaction.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as transaction:
            raw = transaction.get(key)

        image = Image.open(BytesIO(raw))
        image = self.transform(image)
        return image, 0

================================================
FILE: ddpm_exp/datasets/lsun.py
================================================
from .vision import VisionDataset
from PIL import Image
import os
import os.path
import io
from collections.abc import Iterable
import pickle
from torchvision.datasets.utils import verify_str_arg, iterable_to_str


class LSUNClass(VisionDataset):
    """One LSUN category backed by a single LMDB database.

    The list of record keys is cached next to the database (file name
    ``_cache_<dbdir>``) so subsequent runs skip the full cursor scan.
    """

    def __init__(self, root, transform=None, target_transform=None):
        import lmdb

        super(LSUNClass, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()["entries"]

        root_split = root.split("/")
        cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
        if os.path.isfile(cache_file):
            # Fix: the original used pickle.load(open(...)) and leaked the
            # file handle; `with` closes it deterministically.
            with open(cache_file, "rb") as cache_fp:
                self.keys = pickle.load(cache_fp)
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            with open(cache_file, "wb") as cache_fp:
                pickle.dump(self.keys, cache_fp)

    def __getitem__(self, index):
        """Return (image, target); target stays None unless target_transform maps it."""
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        # BytesIO accepts the initial buffer directly; no write/seek dance needed.
        buf = io.BytesIO(imgbuf)
        img = Image.open(buf).convert("RGB")

        target = None
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return self.length


class LSUN(VisionDataset):
    """
    `LSUN <https://www.yf.io/p/lsun>`_ dataset.

    Concatenates one :class:`LSUNClass` per requested category; the target
    returned for each sample is the index of the category it came from.

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, classes="train", transform=None, target_transform=None):
        super(LSUN, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        self.classes = self._verify_classes(classes)

        # for each class, create an LSUNClass dataset backed by "<root>/<class>_lmdb"
        self.dbs = []
        for c in self.classes:
            self.dbs.append(
                LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
            )

        # self.indices[i] is the cumulative sample count through database i;
        # __getitem__ scans it to map a flat index to (database, local index).
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count

    def _verify_classes(self, classes):
        """Normalize ``classes`` into a list of '<category>_<split>' strings.

        A bare split name ('train'/'val'/'test') is expanded over all known
        categories ('test' maps to the single class 'test'); an iterable of
        explicit '<category>_<split>' names is validated element by element.
        Raises ValueError for anything else.
        """
        categories = [
            "bedroom",
            "bridge",
            "church_outdoor",
            "classroom",
            "conference_room",
            "dining_room",
            "kitchen",
            "living_room",
            "restaurant",
            "tower",
        ]
        dset_opts = ["train", "val", "test"]

        try:
            # Case 1: `classes` is one of the split names.
            verify_str_arg(classes, "classes", dset_opts)
            if classes == "test":
                classes = [classes]
            else:
                classes = [c + "_" + classes for c in categories]
        except ValueError:
            # Case 2: `classes` is expected to be an iterable of explicit names.
            if not isinstance(classes, Iterable):
                msg = (
                    "Expected type str or Iterable for argument classes, "
                    "but got type {}."
                )
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            msg_fmtstr = (
                "Expected type str for elements in argument classes, "
                "but got type {}."
            )
            for c in classes:
                verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
                # Split "<category>_<split>" on the LAST underscore; category
                # names themselves may contain underscores (e.g. church_outdoor).
                c_short = c.split("_")
                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
                msg = msg_fmtstr.format(
                    category, "LSUN class", iterable_to_str(categories)
                )
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)

        return classes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.
        """
        target = 0
        sub = 0
        # Walk cumulative counts: `target` becomes the sub-dataset index and
        # `sub` the number of samples in all preceding databases.
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img, target

    def __len__(self):
        return self.length

    def extra_repr(self):
        return "Classes: {classes}".format(**self.__dict__)


================================================
FILE: ddpm_exp/datasets/utils.py
================================================
import os
import os.path
import hashlib
import errno
from torch.utils.model_zoo import tqdm


def gen_bar_updater():
    """Return a urlretrieve-style reporthook that drives a tqdm progress bar."""
    progress_bar = tqdm(total=None)

    def update_hook(count, block_size, total_size):
        # The total is only known once the server reports a content length.
        if progress_bar.total is None and total_size:
            progress_bar.total = total_size
        downloaded = count * block_size
        progress_bar.update(downloaded - progress_bar.n)

    return update_hook


def check_integrity(fpath, md5=None):
    """Return True if no checksum is requested, else whether fpath exists and matches md5."""
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False
    digest = hashlib.md5()
    with open(fpath, 'rb') as fh:
        # Hash the file in 1MB chunks to bound memory use.
        while True:
            block = fh.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest() == md5


def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
    """
    # FileExistsError is exactly the OSError subclass raised for errno EEXIST
    # on Python 3, matching the original errno check.
    try:
        os.makedirs(dirpath)
    except FileExistsError:
        pass


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        OSError: if the download fails (after one http retry for https URLs).
    """
    # Fix: use the stdlib directly instead of the py2/py3 `six.moves` shim.
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # Skip the download when a verified copy is already on disk.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except OSError:
        if url[:5] == 'https':
            # Some mirrors only serve plain http; retry once after downgrading.
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater()
            )
        else:
            # Fix: the original silently swallowed failures for non-https
            # URLs, leaving no file and no error. Surface them instead.
            raise


def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    names = [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]
    if prefix is True:
        names = [os.path.join(root, entry) for entry in names]
    return names


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    names = [
        entry for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]
    if prefix is True:
        names = [os.path.join(root, entry) for entry in names]
    return names


def download_file_from_google_drive(file_id, root, filename=None, md5=None):
    """Download a Google Drive file from  and place it in root.

    Args:
        file_id (str): id of file to be downloaded
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the id of the file.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
    import requests
    url = "https://docs.google.com/uc?export=download"

    root = os.path.expanduser(root)
    if not filename:
        filename = file_id
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # Skip the download when a verified copy is already on disk.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)

    # Large files trigger a virus-scan interstitial; re-request with the
    # confirmation token from the warning cookie when one is present.
    token = _get_confirm_token(response)
    if token:
        response = session.get(
            url, params={'id': file_id, 'confirm': token}, stream=True
        )

    _save_response_content(response, fpath)


def _get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def _save_response_content(response, destination, chunk_size=32768):
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()


================================================
FILE: ddpm_exp/datasets/vision.py
================================================
import os
import torch
import torch.utils.data as data


class VisionDataset(data.Dataset):
    """Common base for the dataset wrappers in this package.

    Expands ``root``, wires up the transform/transforms attributes, and
    provides a uniform ``__repr__``. Subclasses must implement
    ``__getitem__`` and ``__len__``.
    """

    # Number of spaces used to indent body lines in __repr__.
    _repr_indent = 4

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        """
        Args:
            root: Dataset root directory; '~' is expanded when a str is given.
            transforms: Single callable applied to (input, target) pairs.
                Mutually exclusive with transform/target_transform.
            transform: Callable applied to the input only.
            target_transform: Callable applied to the target only.
        """
        # Fix: torch._six.string_classes was removed from recent PyTorch
        # releases; on Python 3 a plain isinstance(str) check is equivalent.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __repr__(self):
        """Multi-line summary: class name, size, root, and transform reprs."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        # Indent continuation lines so multi-line transform reprs stay aligned
        # under the head label.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        """Hook for subclasses to add lines to __repr__; empty by default."""
        return ""


class StandardTransform(object):
    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform, head):
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def __repr__(self):
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transform: ")

        return '\n'.join(body)


================================================
FILE: ddpm_exp/draw_ssim_pruned_curve.py
================================================
import matplotlib.pyplot as plt
import numpy as np

# The 'seaborn-whitegrid' style was renamed to 'seaborn-v0_8-whitegrid' in
# matplotlib 3.6 and the old alias was later removed; try the original name
# first so the plot looks identical on old installs, then fall back.
try:
    plt.style.use('seaborn-whitegrid')
except OSError:
    plt.style.use('seaborn-v0_8-whitegrid')

ssim = [0.7881933450698853, 0.8069899082183838, 0.8119480609893799, 0.8162015080451965, 0.8389594554901123, 0.8415904641151428, 0.8398601412773132, 0.8351159691810608, 0.8382353186607361, 0.8380391001701355, 0.8358467221260071, 0.8335589170455933, 0.8339887857437134, 0.8341929316520691, 0.8322316408157349, 0.8351540565490723, 0.8365049958229065, 0.8395034074783325, 0.8369854092597961, 0.8361033797264099]
loss = [1701.1046142578125, 1480.7890625, 1344.899658203125, 1244.06982421875, 1163.198486328125, 1095.482421875, 1037.3287353515625, 986.4912109375, 941.4730224609375, 901.2019653320312, 864.8944091796875, 831.9296875, 801.815185546875, 774.1593017578125, 748.635498046875, 724.9940795898438, 703.0166015625, 682.5120239257812, 663.32568359375, 645.3184204101562, 628.3749389648438, 612.4005126953125, 597.3052978515625, 583.0064697265625, 569.4385986328125, 556.5473022460938, 544.2760009765625, 532.5770874023438, 521.4034423828125, 510.709716796875, 500.46917724609375, 490.6474914550781, 481.222412109375, 472.17437744140625, 463.4827575683594, 455.11700439453125, 447.05352783203125, 439.2783203125, 431.7735595703125, 424.529541015625, 417.5376281738281, 410.7806701660156, 404.24041748046875, 397.9093017578125, 391.771728515625, 385.82086181640625, 380.05987548828125, 374.47900390625, 369.06243896484375, 363.7962646484375, 358.68304443359375, 353.7085266113281, 348.8709716796875, 344.1724853515625, 339.605224609375, 335.15997314453125, 330.8258972167969, 326.5983581542969, 322.47613525390625, 318.45684814453125, 314.5384826660156, 310.71697998046875, 306.98724365234375, 303.3455810546875, 299.7877197265625, 296.310791015625, 292.9114990234375, 289.58966064453125, 286.3433837890625, 283.16998291015625, 280.06707763671875, 277.03326416015625, 274.06695556640625, 271.161865234375, 268.31640625, 265.52734375, 262.79425048828125, 260.1163330078125, 257.49163818359375, 254.91751098632812, 252.3910369873047, 249.91122436523438, 247.47616577148438, 245.088623046875, 242.750732421875, 240.45819091796875, 238.20266723632812, 235.9857177734375, 233.80799865722656, 231.66690063476562, 229.56248474121094, 227.4961395263672, 225.4654541015625, 223.46958923339844, 221.50381469726562, 219.56704711914062, 217.66116333007812, 215.78857421875, 213.94786071777344, 212.13548278808594, 210.34634399414062, 208.5807342529297, 206.84222412109375, 205.13070678710938, 203.44268798828125, 
201.78305053710938, 200.1453399658203, 198.532958984375, 196.9427490234375, 195.3751220703125, 193.82810974121094, 192.3029022216797, 190.79945373535156, 189.3192596435547, 187.86029052734375, 186.42015075683594, 184.99755859375, 183.5935516357422, 182.20855712890625, 180.8433380126953, 179.49468994140625, 178.16262817382812, 176.84832763671875, 175.5514678955078, 174.27105712890625, 173.00845336914062, 171.76220703125, 170.53294372558594, 169.3196563720703, 168.12110900878906, 166.9364471435547, 165.76898193359375, 164.61373901367188, 163.47064208984375, 162.33859252929688, 161.22225952148438, 160.12184143066406, 159.03469848632812, 157.95745849609375, 156.89071655273438, 155.836181640625, 154.79522705078125, 153.76693725585938, 152.75314331054688, 151.74832153320312, 150.75482177734375, 149.77114868164062, 148.79794311523438, 147.83566284179688, 146.88336181640625, 145.94140625, 145.0076141357422, 144.08383178710938, 143.16946411132812, 142.2662353515625, 141.37387084960938, 140.49444580078125, 139.6236572265625, 138.75778198242188, 137.9005889892578, 137.05108642578125, 136.20909118652344, 135.3750457763672, 134.55081176757812, 133.73622131347656, 132.92990112304688, 132.1318359375, 131.3417510986328, 130.55929565429688, 129.7850341796875, 129.01817321777344, 128.2555694580078, 127.50141906738281, 126.75614929199219, 126.01831817626953, 125.28939819335938, 124.56617736816406, 123.84973907470703, 123.13714599609375, 122.43183135986328, 121.73556518554688, 121.05023956298828, 120.36729431152344, 119.6861801147461, 119.00979614257812, 118.34149932861328, 117.68106079101562, 117.0245361328125, 116.37326049804688, 115.72502899169922, 115.0794677734375, 114.43995666503906, 113.80744934082031, 113.185302734375, 112.56655883789062, 111.9501724243164, 111.33888244628906, 110.731201171875, 110.12779998779297, 109.52902221679688, 108.93549346923828, 108.3429183959961, 107.75627899169922, 107.17346954345703, 106.59703063964844, 106.02497863769531, 105.45870971679688, 
104.89581298828125, 104.33368682861328, 103.77556610107422, 103.22441101074219, 102.67940521240234, 102.13814544677734, 101.595947265625, 101.05874633789062, 100.52738189697266, 99.99861145019531, 99.47378540039062, 98.95279693603516, 98.43711853027344, 97.927490234375, 97.41922760009766, 96.91082763671875, 96.40644836425781, 95.91044616699219, 95.42032623291016, 94.93365478515625, 94.44966125488281, 93.96620178222656, 93.48726654052734, 93.01268005371094, 92.54640197753906, 92.08195495605469, 91.62220764160156, 91.16517639160156, 90.71146392822266, 90.26289367675781, 89.81585693359375, 89.37476348876953, 88.93356323242188, 88.49430084228516, 88.05615234375, 87.62488555908203, 87.20310974121094, 86.78509521484375, 86.36715698242188, 85.94837188720703, 85.5322036743164, 85.119384765625, 84.70726013183594, 84.29851531982422, 83.8918685913086, 83.4896469116211, 83.08975219726562, 82.6915283203125, 82.2962417602539, 81.90049743652344, 81.50473022460938, 81.11053466796875, 80.71994018554688, 80.33277130126953, 79.95138549804688, 79.5705337524414, 79.19496154785156, 78.82186889648438, 78.44967651367188, 78.078125, 77.70683288574219, 77.33111572265625, 76.96113586425781, 76.59521484375, 76.23463439941406, 75.8738784790039, 75.5166015625, 75.15583038330078, 74.7983627319336, 74.4422378540039, 74.08897399902344, 73.74169921875, 73.39446258544922, 73.05160522460938, 72.7100830078125, 72.37609100341797, 72.03876495361328, 71.70317077636719, 71.36734008789062, 71.03208923339844, 70.6976318359375, 70.36759185791016, 70.03667449951172, 69.70942687988281, 69.38262176513672, 69.05623626708984, 68.73023223876953, 68.40640258789062, 68.08369445800781, 67.76331329345703, 67.44386291503906, 67.12708282470703, 66.80943298339844, 66.49469757080078, 66.18429565429688, 65.87628173828125, 65.56700897216797, 65.25809478759766, 64.94975280761719, 64.64031219482422, 64.33030700683594, 64.01728057861328, 63.7086181640625, 63.404884338378906, 63.10600662231445, 62.80165481567383, 
62.49952697753906, 62.1977653503418, 61.896236419677734, 61.5944709777832, 61.29471969604492, 60.99738311767578, 60.701576232910156, 60.40677261352539, 60.11320877075195, 59.82258224487305, 59.53261947631836, 59.24340057373047, 58.95545959472656, 58.671363830566406, 58.38852310180664, 58.108306884765625, 57.82665252685547, 57.54595947265625, 57.266910552978516, 56.98750686645508, 56.709312438964844, 56.431488037109375, 56.154640197753906, 55.880706787109375, 55.60820007324219, 55.33389663696289, 55.061912536621094, 54.790802001953125, 54.52444076538086, 54.25933837890625, 53.99445724487305, 53.730342864990234, 53.46611022949219, 53.19966125488281, 52.93572998046875, 52.67394256591797, 52.416324615478516, 52.15863037109375, 51.9014892578125, 51.64102554321289, 51.380699157714844, 51.123252868652344, 50.869956970214844, 50.619815826416016, 50.36848068237305, 50.117496490478516, 49.864952087402344, 49.614585876464844, 49.36606216430664, 49.11790084838867, 48.872718811035156, 48.6287841796875, 48.383766174316406, 48.13976287841797, 47.89812469482422, 47.65973663330078, 47.42310333251953, 47.187564849853516, 46.95360565185547, 46.718772888183594, 46.485862731933594, 46.25577926635742, 46.027374267578125, 45.79834747314453, 45.57139587402344, 45.344879150390625, 45.12104797363281, 44.89805603027344, 44.67631530761719, 44.45355987548828, 44.230438232421875, 44.00713348388672, 43.78421401977539, 43.56409454345703, 43.346736907958984, 43.13029098510742, 42.91291809082031, 42.694793701171875, 42.47801971435547, 42.26348876953125, 42.0494384765625, 41.83710479736328, 41.625648498535156, 41.414955139160156, 41.20560073852539, 40.99818420410156, 40.791481018066406, 40.585731506347656, 40.380332946777344, 40.17467498779297, 39.96947479248047, 39.76425552368164, 39.55950164794922, 39.355892181396484, 39.152591705322266, 38.949493408203125, 38.745704650878906, 38.54241180419922, 38.34136199951172, 38.14094924926758, 37.94009017944336, 37.73867416381836, 37.537044525146484, 
37.336360931396484, 37.13530731201172, 36.934104919433594, 36.73320007324219, 36.53413391113281, 36.33588409423828, 36.13719940185547, 35.937522888183594, 35.739013671875, 35.542327880859375, 35.34554672241211, 35.149192810058594, 34.954002380371094, 34.75892639160156, 34.56444549560547, 34.37016296386719, 34.175926208496094, 33.980865478515625, 33.78504943847656, 33.58888244628906, 33.393436431884766, 33.19927215576172, 33.0058708190918, 32.81242752075195, 32.61833953857422, 32.423622131347656, 32.22987365722656, 32.03801345825195, 31.847660064697266, 31.658077239990234, 31.468374252319336, 31.278139114379883, 31.088520050048828, 30.89966583251953, 30.711360931396484, 30.523611068725586, 30.336217880249023, 30.14923095703125, 29.963003158569336, 29.778030395507812, 29.593975067138672, 29.410850524902344, 29.228641510009766, 29.04698371887207, 28.86606216430664, 28.686256408691406, 28.507423400878906, 28.328943252563477, 28.150875091552734, 27.97317123413086, 27.796123504638672, 27.619800567626953, 27.44475555419922, 27.27039337158203, 27.096721649169922, 26.923952102661133, 26.751808166503906, 26.58050537109375, 26.410259246826172, 26.241487503051758, 26.073375701904297, 25.90522003173828, 25.737794876098633, 25.571483612060547, 25.406145095825195, 25.241474151611328, 25.077468872070312, 24.914419174194336, 24.752695083618164, 24.592121124267578, 24.43234634399414, 24.273231506347656, 24.114547729492188, 23.955957412719727, 23.798091888427734, 23.641109466552734, 23.484966278076172, 23.329912185668945, 23.17569923400879, 23.02271842956543, 22.87088966369629, 22.720455169677734, 22.570812225341797, 22.421768188476562, 22.27305793762207, 22.12588119506836, 21.979734420776367, 21.83547592163086, 21.692113876342773, 21.54950714111328, 21.406421661376953, 21.263275146484375, 21.122180938720703, 20.982616424560547, 20.842681884765625, 20.70368766784668, 20.565185546875, 20.427291870117188, 20.289962768554688, 20.153470993041992, 20.017803192138672, 19.883617401123047, 
19.750282287597656, 19.616554260253906, 19.483001708984375, 19.3499755859375, 19.217700958251953, 19.086288452148438, 18.956266403198242, 18.826961517333984, 18.699195861816406, 18.57220458984375, 18.44577407836914, 18.320362091064453, 18.19457244873047, 18.069015502929688, 17.943695068359375, 17.81871795654297, 17.69390106201172, 17.569774627685547, 17.445335388183594, 17.32052230834961, 17.196054458618164, 17.072261810302734, 16.949085235595703, 16.826692581176758, 16.704326629638672, 16.581392288208008, 16.458843231201172, 16.336633682250977, 16.215139389038086, 16.094675064086914, 15.97494888305664, 15.855484962463379, 15.736421585083008, 15.61740493774414, 15.499762535095215, 15.383194923400879, 15.267399787902832, 15.1524658203125, 15.037047386169434, 14.921875, 14.807332992553711, 14.6937837600708, 14.581942558288574, 14.471169471740723, 14.359855651855469, 14.247969627380371, 14.136343002319336, 14.0252103805542, 13.91417121887207, 13.804189682006836, 13.695024490356445, 13.586652755737305, 13.478588104248047, 13.370796203613281, 13.263662338256836, 13.156829833984375, 13.050079345703125, 12.943574905395508, 12.8375244140625, 12.732397079467773, 12.627817153930664, 12.523782730102539, 12.420042037963867, 12.316635131835938, 12.213995933532715, 12.112537384033203, 12.011775970458984, 11.911060333251953, 11.81020736694336, 11.709442138671875, 11.609159469604492, 11.510124206542969, 11.41226577758789, 11.314695358276367, 11.217529296875, 11.120745658874512, 11.024354934692383, 10.928537368774414, 10.833683013916016, 10.739753723144531, 10.646350860595703, 10.553449630737305, 10.460601806640625, 10.368167877197266, 10.275918960571289, 10.184614181518555, 10.094273567199707, 10.00442123413086, 9.915353775024414, 9.826492309570312, 9.737730979919434, 9.649423599243164, 9.56181526184082, 9.474851608276367, 9.3887939453125, 9.303421020507812, 9.218555450439453, 9.134313583374023, 9.050726890563965, 8.967226028442383, 8.884288787841797, 8.80159854888916, 
8.71987533569336, 8.638511657714844, 8.55799674987793, 8.477949142456055, 8.398301124572754, 8.319267272949219, 8.240785598754883, 8.16270637512207, 8.084941864013672, 8.00758171081543, 7.930901050567627, 7.854917526245117, 7.779313087463379, 7.70433235168457, 7.629762649536133, 7.555689811706543, 7.482339859008789, 7.409557342529297, 7.336893081665039, 7.26435661315918, 7.192422866821289, 7.1210832595825195, 7.050585746765137, 6.980339527130127, 6.910588264465332, 6.841214179992676, 6.772533893585205, 6.7045183181762695, 6.637387275695801, 6.570642948150635, 6.504342079162598, 6.438445568084717, 6.372898101806641, 6.307552337646484, 6.242652416229248, 6.1783342361450195, 6.114615440368652, 6.05164909362793, 5.989071369171143, 5.926817893981934, 5.864975929260254, 5.8035430908203125, 5.742734909057617, 5.6824235916137695, 5.622620105743408, 5.56337308883667, 5.504632472991943, 5.44635009765625, 5.388323783874512, 5.330789089202881, 5.273573398590088, 5.216763019561768, 5.160672187805176, 5.105111122131348, 5.0502824783325195, 4.995979309082031, 4.941986083984375, 4.887808799743652, 4.834039688110352, 4.780519485473633, 4.727843284606934, 4.675847053527832, 4.624444007873535, 4.573704719543457, 4.523074150085449, 4.472465515136719, 4.4224443435668945, 4.372983932495117, 4.324192047119141, 4.275915145874023, 4.227880954742432, 4.180074691772461, 4.132648944854736, 4.085781097412109, 4.03977632522583, 3.993844747543335, 3.948244094848633, 3.9029383659362793, 3.8578546047210693, 3.8133997917175293, 3.7693405151367188, 3.7258453369140625, 3.682644844055176, 3.6398329734802246, 3.5972371101379395, 3.555053234100342, 3.51322078704834, 3.47196102142334, 3.4310460090637207, 3.3906357288360596, 3.350480079650879, 3.3109285831451416, 3.27160382270813, 3.2327582836151123, 3.1944308280944824, 3.156402826309204, 3.118509292602539, 3.081109046936035, 3.0437707901000977, 3.006779670715332, 2.970116138458252, 2.9337751865386963, 2.8978099822998047, 2.862208366394043, 
2.827075958251953, 2.7925100326538086, 2.75809383392334, 2.7239840030670166, 2.690075159072876, 2.6565277576446533, 2.6233577728271484, 2.59077787399292, 2.5584311485290527, 2.52634859085083, 2.494575262069702, 2.4631528854370117, 2.4322350025177, 2.401658535003662, 2.3714113235473633, 2.341179370880127, 2.311309576034546, 2.281765937805176, 2.2527246475219727, 2.223891496658325, 2.19508695602417, 2.166395664215088, 2.1379075050354004, 2.109847068786621, 2.0825068950653076, 2.055701732635498, 2.0291335582733154, 2.0026559829711914, 1.9766024351119995, 1.9506968259811401, 1.9250354766845703, 1.8999783992767334, 1.8748676776885986, 1.8499674797058105, 1.825405478477478, 1.80095636844635, 1.7769408226013184, 1.7534890174865723, 1.7300043106079102, 1.7066116333007812, 1.6833750009536743, 1.6603283882141113, 1.637768030166626, 1.615763783454895, 1.5940053462982178, 1.5724085569381714, 1.5510432720184326, 1.5298420190811157, 1.5087008476257324, 1.4878756999969482, 1.46720552444458, 1.4469165802001953, 1.42703378200531, 1.4072201251983643, 1.3874483108520508, 1.368030309677124, 1.3487962484359741, 1.3297483921051025, 1.3111402988433838, 1.2925734519958496, 1.2742488384246826, 1.2563354969024658, 1.239056944847107, 1.221704125404358, 1.2041783332824707, 1.1868020296096802, 1.169940710067749, 1.1533498764038086, 1.137148380279541, 1.1206257343292236, 1.1041970252990723, 1.0881215333938599, 1.0724332332611084, 1.0570216178894043, 1.0418777465820312, 1.027139663696289, 1.01233971118927, 0.9975317716598511, 0.982831597328186, 0.9682676196098328, 0.9541558027267456, 0.9400227665901184, 0.9261940717697144, 0.912714421749115, 0.8991694450378418, 0.8859115839004517, 0.8725903630256653, 0.8595502376556396, 0.8467142581939697, 0.8339176774024963, 0.8214421272277832, 0.8089535236358643, 0.7969492673873901, 0.7850332856178284, 0.773577868938446, 0.7619916796684265, 0.750514566898346, 0.7390782833099365, 0.727920651435852, 0.7166476845741272, 0.7057903409004211, 0.6950562596321106, 
0.6844556331634521, 0.6740690469741821, 0.6636630296707153, 0.6534810662269592, 0.6433228850364685, 0.633602499961853, 0.6237667202949524, 0.6142684817314148, 0.604796290397644, 0.5954349040985107, 0.5865603685379028, 0.5772913694381714, 0.5682709217071533, 0.5592767000198364, 0.5505585670471191, 0.542014479637146, 0.5336143374443054, 0.5253323316574097, 0.5171076059341431, 0.5091003179550171, 0.5012395977973938, 0.493182897567749, 0.48534587025642395, 0.47781723737716675, 0.4703660011291504, 0.46325814723968506, 0.4561285972595215, 0.4492647051811218, 0.44225677847862244, 0.4352635145187378, 0.4283895790576935, 0.42159581184387207, 0.4147633910179138, 0.40812188386917114, 0.4014328718185425, 0.39500337839126587, 0.3887244760990143, 0.3823932707309723, 0.3762394189834595, 0.37033089995384216, 0.3649292290210724, 0.35929298400878906, 0.3535959720611572, 0.34809982776641846, 0.342424601316452, 0.33691826462745667, 0.3312578797340393, 0.32602420449256897, 0.3210577368736267, 0.3160027861595154, 0.31118467450141907, 0.30630141496658325, 0.3016233444213867, 0.2967444062232971, 0.2918931841850281, 0.2870868444442749, 0.28218063712120056, 0.27778011560440063, 0.2733995318412781, 0.26912352442741394, 0.2651045024394989, 0.2612626552581787, 0.257208913564682, 0.2527666687965393, 0.24844947457313538, 0.2442183792591095, 0.2404608428478241, 0.23674535751342773, 0.2332439422607422, 0.22945661842823029, 0.22556596994400024, 0.22160732746124268, 0.21786770224571228, 0.21428701281547546, 0.21095798909664154, 0.20757737755775452, 0.20441585779190063, 0.20126710832118988, 0.19809749722480774, 0.19500720500946045, 0.19189956784248352, 0.18880324065685272, 0.18582747876644135, 0.18278327584266663, 0.18014571070671082, 0.17744965851306915, 0.1748645305633545, 0.1720936894416809, 0.16946980357170105, 0.16655337810516357, 0.16428059339523315, 0.16178739070892334, 0.15912079811096191, 0.1564910113811493, 0.15409769117832184, 0.1519279032945633, 0.14959266781806946, 0.14729416370391846, 
0.14519476890563965, 0.14297731220722198, 0.14067870378494263, 0.1384662240743637, 0.13619963824748993, 0.13372863829135895, 0.13164782524108887, 0.12953300774097443, 0.12771821022033691, 0.12615008652210236, 0.12415669858455658, 0.1224883422255516, 0.12066398561000824, 0.11871618032455444, 0.11666613817214966, 0.11477735638618469, 0.11323611438274384, 0.11176219582557678, 0.11029437184333801, 0.10888860374689102, 0.10724224895238876, 0.10557863116264343, 0.10384707897901535, 0.10236769914627075, 0.1006280779838562, 0.09903483092784882, 0.09761733561754227, 0.09611119329929352, 0.09474831074476242, 0.09324538707733154, 0.0919327437877655, 0.09046762436628342, 0.08917433023452759, 0.08783926069736481, 0.08677000552415848, 0.08565567433834076, 0.08473670482635498, 0.08394858241081238, 0.08294089138507843, 0.08191481232643127, 0.08059291541576385, 0.07923152297735214, 0.07800912857055664, 0.07714410126209259, 0.07613521814346313, 0.07499660551548004, 0.0738428384065628, 0.0728812888264656, 0.07185603678226471, 0.07075358927249908, 0.06998489797115326, 0.06904087215662003, 0.06817013025283813, 0.0675143226981163, 0.06662875413894653, 0.06585747003555298, 0.06498762965202332, 0.06416307389736176, 0.06331884860992432, 0.06252430379390717, 0.06164319068193436, 0.0608375146985054, 0.06004999577999115, 0.05958651006221771, 0.05899893864989281, 0.05812441185116768, 0.05736687779426575, 0.05654726177453995, 0.05582256615161896, 0.055398836731910706, 0.05473468452692032, 0.05418632924556732, 0.053625207394361496, 0.05313240364193916, 0.05263800173997879, 0.05228545516729355]
# Normalize the loss curve to [0, 1] relative to its peak (returns an ndarray).
loss = np.asarray(loss) / np.max(loss)
# Checkpoint steps at which SSIM was measured: 50, 100, ..., 1000.
stage = list(range(50, 1000 + 1, 50))

import os

# Set the font size
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.size'] = 23

# Define custom colors in a dark color scheme
color_blue = '#002147'
color_red = '#8C1515'
background_color = '#F0F0F0'

# Set the figure size
plt.figure(figsize=(10, 4))

# Set the background color
plt.rcParams['axes.facecolor'] = background_color

# Left axis: SSIM measured at each checkpoint step, with markers.
ax1 = plt.gca()
ax1.set_ylabel('SSIM')
ax1.plot(stage, ssim, marker='o', linestyle='-', label='SSIM', color=color_blue)

# Right axis: normalized loss, one point per recorded iteration.
ax2 = plt.twinx()
ax2.set_ylabel('Relative Loss')
ax2.plot(np.arange(len(loss)), loss, label='Relative Loss', color=color_red)

# Set x-axis label
ax1.set_xlabel('Steps')

# Add vertical dashed line and threshold label.  The label is a raw string so
# the backslash in "\mathcal" is not parsed as an (invalid) escape sequence.
ax1.axvline(x=300, color='#777777aa', linestyle='dashed')
ax2.text(300 + 10, 0.35, r'Threshold $\mathcal{T}$', rotation=90, color='gray', fontsize=12)

# Color each spine/label to match the curve it belongs to, and hide the
# opposite spine so the twin axes read cleanly.
ax1.spines['left'].set_color(color_blue)
ax1.spines['right'].set_visible(False)
ax1.yaxis.label.set_color(color_blue)
ax1.tick_params(axis='y', colors=color_blue)
ax2.spines['right'].set_color(color_red)
ax2.spines['left'].set_visible(False)
ax2.yaxis.label.set_color(color_red)
ax2.tick_params(axis='y', colors=color_red)

# Set the y-axis limits for Loss
ax2.set_ylim([0, np.max(loss)])

# Merge both axes' handles into a single legend box.
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='lower right')

# Add grid
plt.grid(color='white', linestyle='-', linewidth=0.5)

# Set the layout and padding
plt.tight_layout(pad=1.0)

# Save the figure; create the output directory first so savefig does not
# raise FileNotFoundError when 'run/' does not exist yet.
os.makedirs('run', exist_ok=True)
plt.title("CIFAR-10")
plt.savefig('run/ssim_loss.pdf', dpi=300)
plt.savefig('run/ssim_loss.png', dpi=300)

plt.close()










ssim = [0.5799465179443359, 0.6136730909347534, 0.6535662412643433, 0.6940350532531738, 0.7176114916801453, 0.7297082543373108, 0.7317281365394592, 0.7415740489959717, 0.731893002986908, 0.7338791489601135, 0.7469093799591064, 0.7643880248069763, 0.7528172135353088, 0.7652474045753479, 0.7676505446434021, 0.7659737467765808, 0.7663692831993103, 0.7921077609062195, 0.7951106429100037, 0.7900822162628174]
loss = [4539.2880859375, 3768.095947265625, 3357.234375, 3075.648681640625, 2858.8349609375, 2681.1640625, 2529.66748046875, 2396.486083984375, 2277.68798828125, 2170.572265625, 2073.0498046875, 1983.8221435546875, 1901.969482421875, 1826.51416015625, 1756.9361572265625, 1692.67529296875, 1632.9080810546875, 1577.036865234375, 1524.709228515625, 1475.60205078125, 1429.496826171875, 1386.111572265625, 1345.2410888671875, 1306.699462890625, 1270.2772216796875, 1235.816650390625, 1203.143310546875, 1172.09619140625, 1142.5673828125, 1114.4703369140625, 1087.7288818359375, 1062.2257080078125, 1037.8934326171875, 1014.63818359375, 992.4036865234375, 971.1495361328125, 950.8143310546875, 931.31298828125, 912.5927734375, 894.5926513671875, 877.2578125, 860.58203125, 844.5435180664062, 829.0625, 814.1162109375, 799.66796875, 785.71142578125, 772.231201171875, 759.2171630859375, 746.6537475585938, 734.4761962890625, 722.6871948242188, 711.263916015625, 700.1922607421875, 689.4476318359375, 679.0269775390625, 668.9332885742188, 659.1286010742188, 649.6156005859375, 640.3824462890625, 631.3946533203125, 622.6455078125, 614.136962890625, 605.8540649414062, 597.7898559570312, 589.9388427734375, 582.287109375, 574.8184814453125, 567.5428466796875, 560.46044921875, 553.5679931640625, 546.8153076171875, 540.208251953125, 533.760986328125, 527.4842529296875, 521.3709716796875, 515.41015625, 509.5645751953125, 503.8065490722656, 498.16217041015625, 492.656494140625, 487.27862548828125, 482.0224914550781, 476.8771667480469, 471.83258056640625, 466.8781433105469, 462.02825927734375, 457.2860107421875, 452.6307373046875, 448.0614013671875, 443.5903015136719, 439.2196044921875, 434.9393310546875, 430.72430419921875, 426.5743713378906, 422.4966735839844, 418.49322509765625, 414.55938720703125, 410.68951416015625, 406.8865966796875, 403.1524658203125, 399.4714050292969, 395.8497619628906, 392.2903137207031, 388.796630859375, 385.36053466796875, 381.9772033691406, 378.6322326660156, 
375.3372497558594, 372.1002197265625, 368.9242858886719, 365.79498291015625, 362.7067565917969, 359.657470703125, 356.65362548828125, 353.6961669921875, 350.78558349609375, 347.9106750488281, 345.0701904296875, 342.27056884765625, 339.5118713378906, 336.7920227050781, 334.1104736328125, 331.4671630859375, 328.8617858886719, 326.29071044921875, 323.758056640625, 321.26202392578125, 318.79345703125, 316.34747314453125, 313.92901611328125, 311.5457458496094, 309.1994934082031, 306.8837890625, 304.5982360839844, 302.3414611816406, 300.10614013671875, 297.8866271972656, 295.6922607421875, 293.5301513671875, 291.3970031738281, 289.29144287109375, 287.2158203125, 285.16796875, 283.1424560546875, 281.13409423828125, 279.1377258300781, 277.15594482421875, 275.20147705078125, 273.274658203125, 271.3721618652344, 269.4991455078125, 267.65350341796875, 265.8316650390625, 264.02655029296875, 262.23492431640625, 260.45770263671875, 258.69610595703125, 256.95526123046875, 255.234130859375, 253.54019165039062, 251.86520385742188, 250.21116638183594, 248.5689239501953, 246.93984985351562, 245.31585693359375, 243.70863342285156, 242.11697387695312, 240.54721069335938, 238.99844360351562, 237.47918701171875, 235.97952270507812, 234.4829559326172, 232.99127197265625, 231.49490356445312, 230.009033203125, 228.55003356933594, 227.11834716796875, 225.71218872070312, 224.3217010498047, 222.94598388671875, 221.5821990966797, 220.22264099121094, 218.86444091796875, 217.51229858398438, 216.1671600341797, 214.83935546875, 213.5290985107422, 212.23281860351562, 210.94717407226562, 209.6590118408203, 208.39019775390625, 207.14077758789062, 205.90513610839844, 204.6878662109375, 203.48330688476562, 202.28472900390625, 201.0923309326172, 199.90463256835938, 198.7206268310547, 197.548828125, 196.3909912109375, 195.24566650390625, 194.1101531982422, 192.98435974121094, 191.87210083007812, 190.7747344970703, 189.68112182617188, 188.5771484375, 187.48060607910156, 186.40255737304688, 
185.33895874023438, 184.28480529785156, 183.2340545654297, 182.18869018554688, 181.15188598632812, 180.12258911132812, 179.10223388671875, 178.09197998046875, 177.0926513671875, 176.10470581054688, 175.13047790527344, 174.16094970703125, 173.19235229492188, 172.22509765625, 171.26443481445312, 170.31271362304688, 169.37289428710938, 168.43927001953125, 167.5128173828125, 166.6007080078125, 165.69393920898438, 164.78707885742188, 163.87860107421875, 162.9747772216797, 162.08349609375, 161.205810546875, 160.3376922607422, 159.48043823242188, 158.62545776367188, 157.76852416992188, 156.90916442871094, 156.06423950195312, 155.22775268554688, 154.40029907226562, 153.5887451171875, 152.77359008789062, 151.96192932128906, 151.1425018310547, 150.33209228515625, 149.52468872070312, 148.72679138183594, 147.93780517578125, 147.1558837890625, 146.38275146484375, 145.6148681640625, 144.84906005859375, 144.08462524414062, 143.3209991455078, 142.56076049804688, 141.80752563476562, 141.05966186523438, 140.32247924804688, 139.5969696044922, 138.8789520263672, 138.15431213378906, 137.4319305419922, 136.71456909179688, 136.001953125, 135.29641723632812, 134.59725952148438, 133.90882873535156, 133.21835327148438, 132.53102111816406, 131.8437042236328, 131.16131591796875, 130.48312377929688, 129.80667114257812, 129.1376953125, 128.47406005859375, 127.81607055664062, 127.16371154785156, 126.51216125488281, 125.86033630371094, 125.21190643310547, 124.56687927246094, 123.92790985107422, 123.29476165771484, 122.66639709472656, 122.04109954833984, 121.41598510742188, 120.79107666015625, 120.17152404785156, 119.55781555175781, 118.95069885253906, 118.35137176513672, 117.76033020019531, 117.16847229003906, 116.5713119506836, 115.96839904785156, 115.3705825805664, 114.77892303466797, 114.19495391845703, 113.61700439453125, 113.04344177246094, 112.47366333007812, 111.90603637695312, 111.33882141113281, 110.76946258544922, 110.20514678955078, 109.64720916748047, 109.09403228759766, 
108.5472183227539, 108.00634002685547, 107.47209167480469, 106.93167877197266, 106.3891830444336, 105.84819030761719, 105.31497192382812, 104.78697967529297, 104.2626953125, 103.74164581298828, 103.2225341796875, 102.70263671875, 102.18333435058594, 101.66871643066406, 101.15846252441406, 100.65023803710938, 100.14640808105469, 99.64842987060547, 99.1537094116211, 98.66120910644531, 98.17019653320312, 97.68133544921875, 97.19447326660156, 96.70794677734375, 96.22498321533203, 95.74699401855469, 95.27214050292969, 94.80259704589844, 94.33360290527344, 93.86566162109375, 93.39513397216797, 92.92683410644531, 92.46549224853516, 92.01234436035156, 91.56328582763672, 91.11687469482422, 90.66963195800781, 90.21475982666016, 89.76103210449219, 89.30477905273438, 88.85597229003906, 88.41807556152344, 87.98734283447266, 87.558837890625, 87.13265228271484, 86.70266723632812, 86.26994323730469, 85.84093475341797, 85.41881561279297, 85.00090026855469, 84.590087890625, 84.18545532226562, 83.78675842285156, 83.37939453125, 82.96934509277344, 82.55995178222656, 82.15340423583984, 81.7559814453125, 81.36235046386719, 80.9716796875, 80.58477020263672, 80.20170593261719, 79.81788635253906, 79.4328842163086, 79.04986572265625, 78.66758728027344, 78.28589630126953, 77.90595245361328, 77.52804565429688, 77.15101623535156, 76.77655792236328, 76.40097045898438, 76.02364349365234, 75.64501953125, 75.26689147949219, 74.8895263671875, 74.51441192626953, 74.14277648925781, 73.77396392822266, 73.40673828125, 73.03805541992188, 72.6688232421875, 72.29788208007812, 71.93028259277344, 71.56787109375, 71.20892333984375, 70.85118103027344, 70.49273681640625, 70.13478088378906, 69.77824401855469, 69.4261245727539, 69.07501220703125, 68.72384643554688, 68.37470245361328, 68.0278091430664, 67.68260192871094, 67.33564758300781, 66.98927307128906, 66.64806365966797, 66.31043243408203, 65.97486877441406, 65.63951110839844, 65.3048095703125, 64.97119140625, 64.63858795166016, 64.30635070800781, 
63.974857330322266, 63.64590072631836, 63.31915283203125, 62.994956970214844, 62.67194747924805, 62.350990295410156, 62.02912521362305, 61.705020904541016, 61.37890625, 61.056907653808594, 60.73951721191406, 60.42740249633789, 60.118309020996094, 59.811317443847656, 59.502479553222656, 59.189208984375, 58.87731170654297, 58.56908416748047, 58.265113830566406, 57.964420318603516, 57.6660041809082, 57.368896484375, 57.0717658996582, 56.77558898925781, 56.478187561035156, 56.17683029174805, 55.87291717529297, 55.571327209472656, 55.27485656738281, 54.98467254638672, 54.699951171875, 54.41563034057617, 54.129852294921875, 53.84320068359375, 53.55724334716797, 53.274314880371094, 52.993202209472656, 52.71284484863281, 52.43351745605469, 52.15589904785156, 51.879058837890625, 51.6051025390625, 51.334014892578125, 51.0655403137207, 50.798011779785156, 50.53080749511719, 50.26268005371094, 49.99579620361328, 49.73133850097656, 49.46936798095703, 49.208396911621094, 48.94771957397461, 48.687599182128906, 48.42808532714844, 48.169952392578125, 47.91440200805664, 47.6610107421875, 47.409912109375, 47.1600341796875, 46.91320037841797, 46.66697692871094, 46.419395446777344, 46.16619873046875, 45.91246032714844, 45.66108322143555, 45.413063049316406, 45.16651153564453, 44.921051025390625, 44.67573547363281, 44.432029724121094, 44.189178466796875, 43.94789123535156, 43.70808410644531, 43.470550537109375, 43.234169006347656, 42.99854278564453, 42.762184143066406, 42.523956298828125, 42.2872314453125, 42.054325103759766, 41.82463073730469, 41.59766387939453, 41.369163513183594, 41.13751983642578, 40.90327453613281, 40.668785095214844, 40.43696212768555, 40.209022521972656, 39.983795166015625, 39.75868225097656, 39.532135009765625, 39.304534912109375, 39.078147888183594, 38.85414123535156, 38.63133239746094, 38.40913391113281, 38.18644714355469, 37.964088439941406, 37.74298858642578, 37.524375915527344, 37.30720520019531, 37.09025955200195, 36.87240219116211, 36.65234375, 
36.431217193603516, 36.21006774902344, 35.992942810058594, 35.77734375, 35.56315612792969, 35.34864044189453, 35.13450622558594, 34.920257568359375, 34.706077575683594, 34.49102783203125, 34.27655029296875, 34.06257247924805, 33.84934616088867, 33.63843536376953, 33.429710388183594, 33.22266387939453, 33.014068603515625, 32.805091857910156, 32.597267150878906, 32.39238739013672, 32.19000244140625, 31.989036560058594, 31.787109375, 31.58185577392578, 31.37429428100586, 31.16775894165039, 30.96398162841797, 30.763967514038086, 30.566076278686523, 30.369342803955078, 30.171037673950195, 29.970752716064453, 29.76913070678711, 29.568056106567383, 29.370765686035156, 29.175628662109375, 28.980947494506836, 28.78591537475586, 28.58932876586914, 28.39220428466797, 28.19733428955078, 28.00490951538086, 27.81519317626953, 27.626718521118164, 27.439435958862305, 27.252986907958984, 27.067100524902344, 26.88239288330078, 26.699377059936523, 26.517051696777344, 26.334407806396484, 26.151626586914062, 25.969593048095703, 25.789051055908203, 25.610992431640625, 25.433944702148438, 25.257705688476562, 25.081737518310547, 24.905364990234375, 24.730724334716797, 24.556011199951172, 24.38326644897461, 24.210472106933594, 24.03833770751953, 23.867185592651367, 23.69612693786621, 23.526260375976562, 23.35738754272461, 23.189224243164062, 23.020156860351562, 22.851537704467773, 22.68604850769043, 22.52411651611328, 22.364612579345703, 22.205974578857422, 22.04595375061035, 21.883800506591797, 21.722026824951172, 21.5626163482666, 21.407323837280273, 21.252456665039062, 21.09724235534668, 20.938791275024414, 20.779685974121094, 20.62162971496582, 20.4669189453125, 20.3142147064209, 20.162883758544922, 20.01047706604004, 19.85685157775879, 19.703975677490234, 19.553754806518555, 19.40742301940918, 19.264314651489258, 19.122966766357422, 18.97968101501465, 18.834999084472656, 18.690229415893555, 18.546415328979492, 18.402610778808594, 18.258975982666016, 18.11440658569336, 
17.970054626464844, 17.827861785888672, 17.68832015991211, 17.550716400146484, 17.415355682373047, 17.280445098876953, 17.14678192138672, 17.014076232910156, 16.883808135986328, 16.753719329833984, 16.622791290283203, 16.488920211791992, 16.35369300842285, 16.22126007080078, 16.095569610595703, 15.974079132080078, 15.853219032287598, 15.72891616821289, 15.59947395324707, 15.469198226928711, 15.342164993286133, 15.219317436218262, 15.100605964660645, 14.980583190917969, 14.860252380371094, 14.739171028137207, 14.619288444519043, 14.502647399902344, 14.389965057373047, 14.27832317352295, 14.165441513061523, 14.051509857177734, 13.93765640258789, 13.826026916503906, 13.715468406677246, 13.607650756835938, 13.496648788452148, 13.383329391479492, 13.267261505126953, 13.151922225952148, 13.038779258728027, 12.92779541015625, 12.816251754760742, 12.701602935791016, 12.584444999694824, 12.468511581420898, 12.354952812194824, 12.244566917419434, 12.135894775390625, 12.029569625854492, 11.925186157226562, 11.823380470275879, 11.723055839538574, 11.622727394104004, 11.520914077758789, 11.416922569274902, 11.312165260314941, 11.208732604980469, 11.10700798034668, 11.007889747619629, 10.909204483032227, 10.809144020080566, 10.707747459411621, 10.605745315551758, 10.506563186645508, 10.410165786743164, 10.314408302307129, 10.218300819396973, 10.121312141418457, 10.023993492126465, 9.928132057189941, 9.834404945373535, 9.744147300720215, 9.653356552124023, 9.561277389526367, 9.468231201171875, 9.374344825744629, 9.282003402709961, 9.192495346069336, 9.103446960449219, 9.015043258666992, 8.926200866699219, 8.838485717773438, 8.751630783081055, 8.665960311889648, 8.580838203430176, 8.494034767150879, 8.40644645690918, 8.319029808044434, 8.233461380004883, 8.149580001831055, 8.068504333496094, 7.987727165222168, 7.907271385192871, 7.826296806335449, 7.7446088790893555, 7.663700103759766, 7.583346843719482, 7.5041728019714355, 7.424949645996094, 7.34641170501709, 7.268475532531738, 
7.190487861633301, 7.113183975219727, 7.035463333129883, 6.957459449768066, 6.880485534667969, 6.805814743041992, 6.733950138092041, 6.664864540100098, 6.596710205078125, 6.526113986968994, 6.453292369842529, 6.379787445068359, 6.308520317077637, 6.239605903625488, 6.171825408935547, 6.103394985198975, 6.033662796020508, 5.964019775390625, 5.89523458480835, 5.827916145324707, 5.761248588562012, 5.695002555847168, 5.628237724304199, 5.562036991119385, 5.497188568115234, 5.434276580810547, 5.373710632324219, 5.314513683319092, 5.255434036254883, 5.19560432434082, 5.135910987854004, 5.076885223388672, 5.018493175506592, 4.960127353668213, 4.900954246520996, 4.841784477233887, 4.784269332885742, 4.729909420013428, 4.67695951461792, 4.624540328979492, 4.571746349334717, 4.516648292541504, 4.46004056930542, 4.403772354125977, 4.347835063934326, 4.293648719787598, 4.2399163246154785, 4.187976837158203, 4.13656759262085, 4.086839199066162, 4.038580894470215, 3.991074323654175, 3.9420456886291504, 3.892064094543457, 3.8420252799987793, 3.793379306793213, 3.7477054595947266, 3.704021453857422, 3.6613032817840576, 3.6168198585510254, 3.5706210136413574, 3.5225014686584473, 3.4756243228912354, 3.4306745529174805, 3.387819528579712, 3.3469200134277344, 3.306817054748535, 3.266785144805908, 3.227121353149414, 3.188246726989746, 3.1500604152679443, 3.1109695434570312, 3.072086811065674, 3.0339338779449463, 2.9974265098571777, 2.9611244201660156, 2.9252705574035645, 2.887768268585205, 2.846853494644165, 2.8051815032958984, 2.765190601348877, 2.728276491165161, 2.6949098110198975, 2.663076639175415, 2.630977153778076, 2.5972986221313477, 2.5623912811279297, 2.5293757915496826, 2.4986140727996826, 2.4683361053466797, 2.4374566078186035, 2.4048168659210205, 2.371562957763672, 2.3394112586975098, 2.3097755908966064, 2.2830123901367188, 2.2568652629852295, 2.2302122116088867, 2.2025556564331055, 2.1735823154449463, 2.1443448066711426, 2.1145293712615967, 2.0841856002807617, 
2.0533528327941895, 2.0237834453582764, 1.996132254600525, 1.9719631671905518, 1.9500808715820312, 1.9278086423873901, 1.9034643173217773, 1.8765110969543457, 1.8477627038955688, 1.8204681873321533, 1.7961366176605225, 1.7744765281677246, 1.754655361175537, 1.7349741458892822, 1.7141201496124268, 1.6927661895751953, 1.6689279079437256, 1.6446927785873413, 1.619302749633789, 1.5943386554718018, 1.5706288814544678, 1.5495566129684448, 1.5302032232284546, 1.5114333629608154, 1.491368293762207, 1.469764232635498, 1.446757435798645, 1.424503207206726, 1.4041261672973633, 1.3876874446868896, 1.37278151512146, 1.3566217422485352, 1.3387857675552368, 1.3189970254898071, 1.2991602420806885, 1.2807953357696533, 1.2628533840179443, 1.2454438209533691, 1.2298157215118408, 1.215038776397705, 1.2016723155975342, 1.1890926361083984, 1.176241159439087, 1.1628756523132324, 1.1477210521697998, 1.1322383880615234, 1.118009328842163, 1.1052467823028564, 1.0926997661590576, 1.0799956321716309, 1.0650734901428223, 1.0481101274490356, 1.0310444831848145, 1.0148855447769165, 1.0002671480178833, 0.9865742921829224, 0.9726752042770386, 0.9584107398986816, 0.9439973831176758, 0.9304401278495789, 0.9194836616516113, 0.910431444644928, 0.901872992515564, 0.8922662138938904, 0.8820143938064575, 0.8693972229957581, 0.8558073043823242, 0.8431190252304077, 0.8318266868591309, 0.8213794827461243, 0.8131312727928162, 0.8069919347763062, 0.8026072978973389, 0.7977644801139832, 0.7888401746749878, 0.7746130228042603, 0.7577764391899109, 0.7395559549331665, 0.7252950072288513, 0.7153338193893433, 0.7092757225036621, 0.7042636275291443, 0.6987306475639343, 0.6903582215309143, 0.6798433065414429, 0.6693308353424072, 0.6607911586761475, 0.6553004384040833, 0.6516624689102173, 0.6483136415481567, 0.6432492136955261, 0.637688159942627, 0.6315328478813171, 0.6241682767868042, 0.6156589984893799, 0.606623649597168, 0.5980398654937744, 0.5924094319343567, 0.5872237682342529, 0.5824892520904541, 
0.5762549638748169, 0.5678626298904419, 0.5592249631881714, 0.5521055459976196, 0.5471092462539673, 0.5430928468704224, 0.5402786731719971, 0.5361867547035217, 0.530924916267395, 0.5247969627380371, 0.5195049047470093, 0.5150901079177856, 0.5105984210968018, 0.5053007006645203, 0.49872955679893494, 0.49299514293670654, 0.48698410391807556, 0.48231664299964905, 0.4776288866996765, 0.4725630283355713, 0.46631789207458496, 0.4607853293418884, 0.4565424621105194, 0.4528442919254303, 0.4492824077606201, 0.44360923767089844, 0.436774343252182, 0.4306948781013489, 0.42571407556533813, 0.42529743909835815, 0.4260627329349518, 0.42652514576911926, 0.42327821254730225, 0.4168250560760498, 0.40701374411582947, 0.39881962537765503, 0.392905592918396, 0.38991519808769226, 0.388450026512146, 0.38770100474357605, 0.3859521150588989, 0.3823249340057373, 0.3760799765586853, 0.3682119846343994, 0.35995787382125854, 0.35440850257873535, 0.35170888900756836, 0.3540180027484894, 0.35911762714385986, 0.36555320024490356, 0.36951082944869995, 0.368580162525177, 0.3634859323501587, 0.3560991585254669, 0.3505817651748657, 0.34617334604263306, 0.34385979175567627, 0.341214120388031, 0.33820509910583496, 0.33600670099258423, 0.33432748913764954, 0.3346988558769226, 0.334031879901886, 0.33080244064331055, 0.32728511095046997, 0.3241499662399292, 0.3206929564476013, 0.3173835873603821, 0.31362858414649963, 0.3082450032234192, 0.3012043535709381]

# Normalize the loss curve to [0, 1] so it can share a figure with SSIM.
loss = loss / np.max(loss)
stage = list(range(50, 1000 + 1, 50))

# Define custom colors in a dark color scheme
color_blue = '#002147'
color_red = '#8C1515'

# Set the figure size
plt.figure(figsize=(10, 4))

# Plot SSIM with markers on the left axis
ax1 = plt.gca()
ax1.set_ylabel('SSIM')
ax1.plot(stage, ssim, marker='o', linestyle='-', label='SSIM', color=color_blue)

# Plot the relative loss on a twin right axis
ax2 = plt.twinx()
ax2.set_ylabel('Relative Loss')
ax2.plot(np.arange(len(loss)), loss, label='Relative Loss', color=color_red)

# Set x-axis label
ax1.set_xlabel('Steps')

# Mark the pruning threshold at step 950.
ax1.axvline(x=950, color='#777777aa', linestyle='dashed')
# Raw string: '\m' in a plain literal is an invalid escape sequence
# (SyntaxWarning on modern Python); the rendered text is unchanged.
ax2.text(950 + 10, 0.35, r'Threshold $\mathcal{T}$', rotation=90, color='gray', fontsize=12)

ax2.set_ylim([0, np.max(loss)])

# Show legend in one box
lines, labels = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines + lines2, labels + labels2, loc='lower right')


# Adjust the alignment of twin axes: color each spine/label to match its curve
ax1.spines['left'].set_color(color_blue)
ax1.spines['right'].set_visible(False)
ax1.yaxis.label.set_color(color_blue)
ax1.tick_params(axis='y', colors=color_blue)
ax2.spines['right'].set_color(color_red)
ax2.spines['left'].set_visible(False)
ax2.yaxis.label.set_color(color_red)
ax2.tick_params(axis='y', colors=color_red)

# Add grid
plt.grid(color='white', linestyle='-', linewidth=0.5)

# Set the layout and padding
plt.tight_layout(pad=1.0)

plt.title("CelebA-HQ")  # fixed typo: was "CeleA-HQ"
# Save the figure in both vector and raster form
plt.savefig('run/ssim_loss2.pdf', dpi=300)
plt.savefig('run/ssim_loss2.png', dpi=300)

plt.close()



================================================
FILE: ddpm_exp/extract_cifar10.py
================================================
import os
import torchvision
from torchvision.datasets import CIFAR10
from tqdm import tqdm

# Define the path to the folder where the extracted PNG images will be saved
save_path = 'data/cifar10/images'

# Create the folder if it doesn't exist; exist_ok avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
os.makedirs(save_path, exist_ok=True)

# Load the CIFAR10 training split (downloaded on first use)
dataset = CIFAR10(root='data/cifar10', train=True, download=True)

# Loop through the dataset and save each image as '<index>.png'
for i in tqdm(range(len(dataset))):
    image, _label = dataset[i]  # label is unused; we only export the pixels
    image_path = os.path.join(save_path, f'{i}.png')
    image.save(image_path)

================================================
FILE: ddpm_exp/fid_score.py
================================================
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a no-op stand-in that simply returns
    # its iterable.  It accepts (and ignores) tqdm's usual extra options so
    # call sites do not need to care which implementation is in use.
    def tqdm(x, *args, **kwargs):
        return x

from inception import InceptionV3

# Command-line interface.  The two positional `path` arguments are either
# directories of generated/real images or precomputed .npz statistic files.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=50,
                    help='Batch size to use')
# Dataset-specific preprocessing switch; only 'celeba' triggers the special
# face crop in get_activations.
parser.add_argument('--dataset_name', type=str, default=None)
parser.add_argument('--num-workers', type=int,
                    help=('Number of processes to use for data loading. '
                          'Defaults to `min(8, num_cpus)`'))
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--dims', type=int, default=2048,
                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                    help=('Dimensionality of Inception features to use. '
                          'By default, uses pool3 features'))
parser.add_argument('--num_samples', type=int, default=None,
                    help=('Number of samples for FID estimation'))
parser.add_argument('--res', type=int, default=None,
                    help=('Resolutions of samples for FID estimation'))
parser.add_argument('--save-stats', action='store_true',
                    help=('Generate an npz archive from a directory of samples. '
                          'The first path is used as input and the second as output.'))

parser.add_argument('path', type=str, nargs=2,
                    help=('Paths to the generated images or '
                          'to .npz statistic files'))


# File extensions (lowercase) considered images when globbing a directory.
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}


class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset that lazily loads RGB images from a list of file paths.

    Each item is the PIL image at ``files[index]`` converted to RGB, run
    through ``transforms`` when one was supplied.
    """

    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        img = Image.open(self.files[i]).convert('RGB')
        if self.transforms is None:
            return img
        return self.transforms(img)


def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
                    num_workers=1, res=None, dataset_name=None):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files        : List of image files paths
    -- model        : Instance of inception model
    -- batch_size   : Batch size of images for the model to process at once.
                      (NOTE: the dataloader uses drop_last=False below, so no
                      samples are ignored even when the data size is not a
                      multiple of the batch size.)
    -- dims         : Dimensionality of features returned by Inception
    -- device       : Device to run calculations
    -- num_workers  : Number of parallel dataloader workers
    -- res          : Optional target resolution; when set, images are
                      resized/cropped to res x res before featurization
    -- dataset_name : When 'celeba' (and res is set), a fixed face crop is
                      applied before resizing

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    # A batch larger than the dataset would yield a single short batch anyway;
    # shrink it and warn so the caller knows.
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    if res is None:
        trans = TF.ToTensor()
    else:
        if dataset_name == 'celeba':
            from datasets import Crop
            # Fixed 128x128 crop window centered at (cx, cy) = (89, 121);
            # presumably the standard CelebA face alignment — TODO confirm
            # against datasets.Crop's (x1, x2, y1, y2) convention.
            cx = 89
            cy = 121
            x1 = cy - 64
            x2 = cy + 64
            y1 = cx - 64
            y2 = cx + 64
            trans = TF.Compose([
                        Crop(x1, x2, y1, y2),
                        TF.Resize(res),
                        TF.ToTensor(),
            ])
        else:
            # Generic path: shortest side to `res`, then center-crop to square.
            trans = TF.Compose([
                TF.Resize(res),
                TF.CenterCrop(res),
                TF.ToTensor()
            ])

    dataset = ImagePathDataset(files, transforms=trans)
    # shuffle=False + drop_last=False: activations come back in `files` order
    # and every file contributes exactly one row.
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    # Preallocated output; rows are filled in file order, batch by batch.
    pred_arr = np.empty((len(files), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        # (B, dims, 1, 1) -> (B, dims) on the CPU as numpy.
        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx:start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    For two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Activation mean for generated samples.
    -- sigma1: Activation covariance for generated samples.
    -- mu2   : Precalculated activation mean of a representative data set.
    -- sigma2: Precalculated activation covariance of the same data set.
    -- eps   : Diagonal offset applied when the covariance product is
               (numerically) singular.

    Returns:
    --   : The Frechet Distance.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_diff = mu1 - mu2

    # Matrix square root of the covariance product; disp=False keeps sqrtm
    # from printing its own warnings and returns an error estimate instead.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Product was (near-)singular: regularize both covariances and retry.
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    return (mean_diff.dot(mean_diff)
            + np.trace(sigma1)
            + np.trace(sigma2)
            - 2 * np.trace(covmean))


def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
                                    device='cpu', num_workers=1, res=None, dataset_name=None):
    """Calculation of the statistics used by the FID.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers
    -- res         : Optional resolution forwarded to get_activations
    -- dataset_name: Optional dataset tag forwarded to get_activations

    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    activations = get_activations(files, model, batch_size, dims, device,
                                  num_workers, res=res, dataset_name=dataset_name)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)


def compute_statistics_of_path(path, model, batch_size, dims, device,
                               num_workers=1, num_samples=None, res=None, dataset_name=None):
    """Return (mu, sigma) for `path`.

    An '.npz' path is loaded directly; any other path is treated as an image
    directory that is globbed recursively for known image extensions.
    """
    if path.endswith('.npz'):
        with np.load(path) as f:
            return f['mu'][:], f['sigma'][:]

    path = pathlib.Path(path)
    files = sorted(file for ext in IMAGE_EXTENSIONS
                   for file in path.glob('**/*.{}'.format(ext)))
    if num_samples is not None:
        # Deterministic subset: take the first num_samples files in order.
        files = files[:num_samples]
    print("Found %d files." % len(files))
    return calculate_activation_statistics(files, model, batch_size, dims,
                                           device, num_workers, res=res,
                                           dataset_name=dataset_name)


def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, num_samples=None, res=None, dataset_name=None):
    """Calculates the FID of two paths (image directories or .npz stats)."""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    # One Inception model is shared for both sides of the comparison.
    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)

    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, dims,
                                        device, num_workers,
                                        num_samples=num_samples, res=res,
                                        dataset_name=dataset_name)
    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, dims,
                                        device, num_workers,
                                        num_samples=num_samples, res=res,
                                        dataset_name=dataset_name)
    return calculate_frechet_distance(m1, s1, m2, s2)


def save_fid_stats(paths, batch_size, device, dims, num_workers=1, num_samples=None, res=None, dataset_name=None):
    """Compute FID statistics for paths[0] and save them to the npz paths[1].

    Refuses to clobber an existing output file.
    """
    in_path, out_path = paths[0], paths[1]

    if not os.path.exists(in_path):
        raise RuntimeError('Invalid path: %s' % in_path)
    if os.path.exists(out_path):
        raise RuntimeError('Existing output file: %s' % out_path)

    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)

    print(f"Saving statistics for {in_path}")

    mu, sigma = compute_statistics_of_path(in_path, model, batch_size, dims,
                                           device, num_workers,
                                           num_samples=num_samples, res=res,
                                           dataset_name=dataset_name)
    np.savez_compressed(out_path, mu=mu, sigma=sigma)


def main():
    """CLI entry point: parse args, then either save stats or print the FID."""
    args = parser.parse_args()

    if args.device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    if args.num_workers is not None:
        num_workers = args.num_workers
    else:
        try:
            # Number of CPUs actually available to this process.
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available under Windows, use
            # os.cpu_count instead (which may not return the *available*
            # number of CPUs).
            num_cpus = os.cpu_count()
        num_workers = min(num_cpus, 8) if num_cpus is not None else 0

    if args.save_stats:
        save_fid_stats(args.path, args.batch_size, device, args.dims,
                       num_workers, num_samples=args.num_samples,
                       res=args.res, dataset_name=args.dataset_name)
        return

    fid_value = calculate_fid_given_paths(args.path, args.batch_size, device,
                                          args.dims, num_workers,
                                          num_samples=args.num_samples,
                                          res=args.res,
                                          dataset_name=args.dataset_name)
    print('FID: ', fid_value)


if __name__ == '__main__':
    main()


================================================
FILE: ddpm_exp/finetune.py
================================================
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb
from tqdm import tqdm
from runners.diffusion import Diffusion
from torchvision import transforms
import torchvision
from datasets import get_dataset, data_transform, inverse_data_transform
from utils import UnlabeledImageFolder
from accelerate import Accelerator
# Print tensors in fixed-point (not scientific) notation for readable logs.
torch.set_printoptions(sci_mode=False)


def parse_args_and_config(accelerator):
    """Parse command-line arguments and the YAML config file.

    Side effects: the main process (per ``accelerator``) creates or overwrites
    the experiment log folder, dumps the resolved config to ``config.yml`` and
    installs stream/file logging handlers.  Every process seeds the RNGs and,
    in ``--sample`` mode, gets its own per-process image output folder.

    Params:
    -- accelerator : accelerate.Accelerator; used to detect the main process
                     and to name the per-process sample sub-folder.

    Returns:
    -- (args, config) : the parsed CLI namespace and the YAML config converted
                        to a nested namespace via ``dict2namespace``.
    """
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=2333, help="Random seed")
    parser.add_argument("--taylor_batch_size", type=int, default=128, help="batch size for taylor expansion")
    parser.add_argument(
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(
        "--kd",
        action="store_true",
        default=False,
        # Fixed copy-pasted help text (was the --skip_type description).
        help="enable knowledge distillation during training",
    )
    parser.add_argument(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(
        "--comment", type=str, default="", help="A string for experiment comment"
    )

    parser.add_argument(
        "--load_pruned_model", type=str, default=None, help="load pruned models"
    )

    parser.add_argument(
        # Fixed copy-pasted help text (was "load pruned models").
        "--save_pruned_model", type=str, default=None, help="save pruned models"
    )

    parser.add_argument(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument(
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")
    parser.add_argument("--interpolation", action="store_true")
    parser.add_argument(
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        # Fixed copy-pasted help text (was the --ni description).
        help="use the EMA (exponential moving average) model weights",
    )
    parser.add_argument("--use_pretrained", action="store_true")
    parser.add_argument(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )

    parser.add_argument(
        "--pruner",
        type=str,
        default="taylor",
        choices=["taylor", "random", "magnitude", "reinit", "first_order_taylor", "second_order_taylor"],
    )

    parser.add_argument(
        "--restore_from",
        type=str,
        default=None,
        help="Restore from user a checkpoint",
    )
    parser.add_argument(
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument(
        "--pruning_ratio",
        type=float,
        default=0.0,
        help="pruning ratio",
    )

    parser.add_argument("--sequence", action="store_true")

    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)

    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

        #tb_path = os.path.join(args.exp, "tensorboard", args.doc)
    # Only the main process touches the filesystem / logging setup.
    if accelerator.is_main_process:
        if not args.test and not args.sample:
            if not args.resume_training:
                if os.path.exists(args.log_path):
                    overwrite = False
                    if args.ni:
                        overwrite = True
                    else:
                        response = input("Folder already exists. Overwrite? (Y/N)")
                        if response.upper() == "Y":
                            overwrite = True

                    if overwrite:
                        shutil.rmtree(args.log_path)
                        #shutil.rmtree(tb_path)
                        os.makedirs(args.log_path)
                        #if os.path.exists(tb_path):
                        #    shutil.rmtree(tb_path)
                    else:
                        print("Folder exists. Program halted.")
                        sys.exit(0)
                else:
                    os.makedirs(args.log_path)

                with open(os.path.join(args.log_path, "config.yml"), "w") as f:
                    yaml.dump(new_config, f, default_flow_style=False)
            os.makedirs(os.path.join(args.log_path, 'vis'), exist_ok=True)
            #new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
            # setup logger: console + stdout.txt in the log folder
            level = getattr(logging, args.verbose.upper(), None)
            if not isinstance(level, int):
                raise ValueError("level {} not supported".format(args.verbose))

            handler1 = logging.StreamHandler()
            handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
            formatter = logging.Formatter(
                "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
            )
            handler1.setFormatter(formatter)
            handler2.setFormatter(formatter)
            logger = logging.getLogger()
            logger.addHandler(handler1)
            logger.addHandler(handler2)
            logger.setLevel(level)
        else:
            # test/sample mode: console-only logging, no run folder is created.
            level = getattr(logging, args.verbose.upper(), None)
            if not isinstance(level, int):
                raise ValueError("level {} not supported".format(args.verbose))

            handler1 = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
            )
            handler1.setFormatter(formatter)
            logger = logging.getLogger()
            logger.addHandler(handler1)
            logger.setLevel(level)

    if args.sample:
        # Each process writes samples into its own sub-folder named after its
        # accelerate process index, so multi-process sampling does not collide.
        os.makedirs(os.path.join(args.exp, "image_samples", args.image_folder, str(accelerator.process_index)), exist_ok=True)
        args.image_folder = os.path.join(
            args.exp, "image_samples", args.image_folder, str(accelerator.process_index)
        )
        if not os.path.exists(args.image_folder):
            os.makedirs(args.image_folder)
        else:
            if not (args.fid or args.interpolation):
                overwrite = False
                if args.ni:
                    overwrite = True
                else:
                    response = input(
                        f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
                    )
                    if response.upper() == "Y":
                        overwrite = True

                if overwrite:
                    shutil.rmtree(args.image_folder)
                    os.makedirs(args.image_folder)
                else:
                    print("Output image folder exists. Program halted.")
                    sys.exit(0)

    # add device
    #device = #torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    #logging.info("Using device: {}".format(device))
    #new_config.device = device

    # set random seed (torch, numpy, and all CUDA devices)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def dict2namespace(config):
    """Recursively convert a nested dict into an argparse.Namespace tree."""
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(
            namespace,
            key,
            dict2namespace(value) if isinstance(value, dict) else value,
        )
    return namespace


def main():
    """Script entry point: set up distributed state, then run the requested mode."""
    accelerator = Accelerator()
    args, config = parse_args_and_config(accelerator)
    for message in (
        "Writing log file to {}".format(args.log_path),
        "Exp instance id = {}".format(os.getpid()),
        "Exp comment = {}".format(args.comment),
    ):
        logging.info(message)

    try:
        runner = Diffusion(args, config)
        runner.accelerator = accelerator
        # Dispatch on the requested mode; training is the default.
        if args.sample:
            runner.sample()
        elif args.test:
            runner.test()
        else:
            runner.train(kd=args.kd)
    except Exception:
        # Log the full traceback instead of crashing the launcher.
        logging.error(traceback.format_exc())

    return 0


# Use main()'s return value as the process exit status (always 0 here;
# exceptions are caught and logged inside main()).
if __name__ == "__main__":
    sys.exit(main())


================================================
FILE: ddpm_exp/finetune_simple.py
================================================
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb
from tqdm import tqdm
from runners.diffusion_simple import Diffusion
from torchvision import transforms
import torchvision
from datasets import get_dataset, data_transform, inverse_data_transform

from utils import UnlabeledImageFolder

# Print tensors in plain decimal notation (no scientific format).
torch.set_printoptions(sci_mode=False)


def parse_args_and_config():
    """Parse CLI arguments and the YAML config for finetune_simple.

    Side effects: creates/overwrites the log folder (or sample folder),
    asking for confirmation unless --ni is given, configures the root
    logger, selects the torch device, and seeds torch/numpy.

    Returns:
        (args, new_config): the argparse namespace and the YAML config
        converted to a namespace with a ``device`` attribute attached.
    """
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=2333, help="Random seed")
    parser.add_argument("--taylor_batch_size", type=int, default=128, help="batch size for taylor expansion")
    parser.add_argument(
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(
        "--comment", type=str, default="", help="A string for experiment comment"
    )

    parser.add_argument(
        "--load_pruned_model", type=str, default=None, help="load pruned models"
    )

    # Help text fixed: this option saves the pruned model (it previously
    # said "load pruned models" by copy-paste).
    parser.add_argument(
        "--save_pruned_model", type=str, default=None, help="save the pruned model"
    )

    parser.add_argument(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument(
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")
    parser.add_argument("--interpolation", action="store_true")
    parser.add_argument(
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    # Help text fixed: it previously duplicated the --ni description.
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use the EMA model weights",
    )
    parser.add_argument("--use_pretrained", action="store_true")
    parser.add_argument(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )

    parser.add_argument(
        "--pruner",
        type=str,
        default="taylor",
        choices=["taylor", "random", "magnitude", "reinit", "first_order_taylor", "second_order_taylor", "ours"],
    )

    parser.add_argument(
        "--restore_from",
        type=str,
        default=None,
        help="Restore from user a checkpoint",
    )
    parser.add_argument(
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    # Help text fixed: it previously duplicated the --eta description.
    # main() uses this as a loss threshold for the 'ours' pruner:
    # gradient accumulation stops once loss < max_loss * thr.
    parser.add_argument(
        "--thr",
        type=float,
        default=0.0,
        help="loss threshold used by the 'ours' pruner to stop gradient accumulation",
    )
    parser.add_argument(
        "--pruning_ratio",
        type=float,
        default=0.0,
        help="pruning ratio",
    )

    parser.add_argument("--sequence", action="store_true")

    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)

    # parse config file (relative to the local "configs" directory)
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    if not args.test and not args.sample:
        # Training mode: prepare a fresh (or resumed) log folder and log to
        # both the console and <log_path>/stdout.txt.
        if not args.resume_training:
            if os.path.exists(args.log_path):
                overwrite = False
                if args.ni:
                    overwrite = True
                else:
                    response = input("Folder already exists. Overwrite? (Y/N)")
                    if response.upper() == "Y":
                        overwrite = True

                if overwrite:
                    shutil.rmtree(args.log_path)
                    os.makedirs(args.log_path)
                else:
                    print("Folder exists. Program halted.")
                    sys.exit(0)
            else:
                os.makedirs(args.log_path)

            # Persist the effective config next to the logs for reproducibility.
            with open(os.path.join(args.log_path, "config.yml"), "w") as f:
                yaml.dump(new_config, f, default_flow_style=False)
        os.makedirs(os.path.join(args.log_path, 'vis'), exist_ok=True)
        # setup logger
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        handler2.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.addHandler(handler2)
        logger.setLevel(level)

    else:
        # Test/sample mode: console logging only.
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.setLevel(level)

        if args.sample:
            os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
            args.image_folder = os.path.join(
                args.exp, "image_samples", args.image_folder
            )
            if not os.path.exists(args.image_folder):
                os.makedirs(args.image_folder)
            else:
                # FID/interpolation runs may reuse an existing folder.
                if not (args.fid or args.interpolation):
                    overwrite = False
                    if args.ni:
                        overwrite = True
                    else:
                        response = input(
                            f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
                        )
                        if response.upper() == "Y":
                            overwrite = True

                    if overwrite:
                        shutil.rmtree(args.image_folder)
                        os.makedirs(args.image_folder)
                    else:
                        print("Output image folder exists. Program halted.")
                        sys.exit(0)

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def dict2namespace(config):
    """Turn a (possibly nested) configuration dict into a Namespace tree."""
    converted = {
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
    return argparse.Namespace(**converted)


def main():
    """Entry point: optionally prune the diffusion model, then sample/test/train.

    Fix: a stray debug ``print(step_k)`` before the dispatch was removed —
    ``step_k`` is only bound when the Taylor gradient-accumulation loop runs,
    so every other path (random/magnitude pruner, --load_pruned_model, no
    pruning) raised NameError and the broad except silently aborted the run
    before sample()/test()/train() was ever called.
    """
    args, config = parse_args_and_config()
    logging.info("Writing log file to {}".format(args.log_path))
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))

    try:
        runner = Diffusion(args, config)
        if args.pruning_ratio > 0 and args.load_pruned_model is None:
            # Dataset used to estimate channel importance (Taylor-style pruners).
            print(config)
            dataset, _ = get_dataset(args, config)
            print(f"Dataset size: {len(dataset)}")
            train_dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.taylor_batch_size, shuffle=True, num_workers=4, drop_last=True
            )

            import torch_pruning as tp
            print("Pruning ...")
            model = runner.model
            model.to(runner.device)

            example_inputs = {'x': torch.randn(1, 3, config.data.image_size, config.data.image_size).to(runner.device), 't': torch.ones(1).to(runner.device)}

            # Select the channel-importance criterion for the chosen pruner.
            if args.pruner == 'taylor':
                imp = tp.importance.TaylorImportance()
            elif args.pruner == 'first_order_taylor':
                imp = tp.importance.FullTaylorImportance(order=1)
            elif args.pruner == 'second_order_taylor':
                imp = tp.importance.FullTaylorImportance(order=2)
            elif args.pruner == 'random' or args.pruner == 'reinit':
                imp = tp.importance.RandomImportance()
            elif args.pruner == 'magnitude':
                imp = tp.importance.MagnitudeImportance()
            elif args.pruner == 'ours':
                imp = tp.importance.TaylorImportance()

            ignored_layers = [model.conv_out]  # keep the final output conv intact
            channel_groups = {}
            iterative_steps = 1
            pruner = tp.pruner.MagnitudePruner(
                model,
                example_inputs,
                importance=imp,
                iterative_steps=iterative_steps,
                channel_groups=channel_groups,
                ch_sparsity=args.pruning_ratio, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
                ignored_layers=ignored_layers,
                root_module_types=[torch.nn.Conv2d, torch.nn.Linear]
            )
            base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

            # Gradient-based criteria need accumulated gradients before pruner.step().
            if 'taylor' in args.pruner or args.pruner == 'ours':
                x = next(iter(train_dataloader))
                if isinstance(x, (list, tuple)):
                    x = x[0]
                x = x.to(runner.device)
                x = data_transform(config, x)
                n = x.size(0)
                e = torch.randn_like(x)
                b = runner.betas
                from functions.losses import loss_registry

                model.zero_grad()
                max_loss = 0
                for step_k in tqdm(range(1000)):
                    t = torch.ones(n, dtype=torch.long).to(runner.device) * step_k
                    loss = loss_registry[config.model.type](model, x, t, e, b)
                    if args.pruner == 'ours':
                        # Early-stop accumulation once the loss drops below a
                        # fraction (args.thr) of the maximum loss seen so far.
                        if loss > max_loss:
                            max_loss = loss
                        if loss < max_loss * args.thr:
                            break
                    loss.backward()

            print("============ Before Pruning ============")
            print(model)
            for g in pruner.step(interactive=True):
                g.prune()

            if args.pruner == 'reinit':
                # Baseline: keep the pruned architecture but discard the
                # trained weights by re-initializing every layer.
                def reset_parameters(model):
                    for m in model.modules():
                        if hasattr(m, 'reset_parameters'):
                            m.reset_parameters()
                model.apply(reset_parameters)

            macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
            print("============ After Pruning ============")
            print(model)
            print("#Params: {:.4f} M => {:.4f} M".format(base_nparams/1e6, nparams/1e6))
            print("#MACs: {:.4f} G => {:.4f} G".format(base_macs/1e9, macs/1e9))
            del pruner
            # Save pruned model
            print("Saving pruned model as {}".format(os.path.join(args.log_path, "pruned_model.pth")))
            torch.save(
                model,
                os.path.join(args.log_path, "pruned_model.pth"),
            )

        if args.load_pruned_model is not None:
            print("Loading pruned model from {}".format(args.load_pruned_model))
            model = torch.load(args.load_pruned_model, map_location='cpu')
            runner.model = model

        if args.sample:
            runner.sample()
        elif args.test:
            runner.test()
        else:
            runner.train()
    except Exception:
        logging.error(traceback.format_exc())

    return 0


# Use main()'s return value as the process exit status (always 0 here;
# exceptions are caught and logged inside main()).
if __name__ == "__main__":
    sys.exit(main())

================================================
FILE: ddpm_exp/functions/__init__.py
================================================
import torch.optim as optim


def get_optimizer(config, parameters):
    """Build the optimizer named by ``config.optim.optimizer``.

    Supported: 'Adam' (betas/eps/amsgrad taken from the config),
    'RMSProp', and 'SGD' (fixed momentum 0.9).
    Raises NotImplementedError for anything else.
    """
    name = config.optim.optimizer
    if name == 'Adam':
        return optim.Adam(
            parameters,
            lr=config.optim.lr,
            weight_decay=config.optim.weight_decay,
            betas=(config.optim.beta1, 0.999),
            amsgrad=config.optim.amsgrad,
            eps=config.optim.eps,
        )
    if name == 'RMSProp':
        return optim.RMSprop(
            parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay
        )
    if name == 'SGD':
        return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
    raise NotImplementedError(
        'Optimizer {} not understood.'.format(config.optim.optimizer))


================================================
FILE: ddpm_exp/functions/ckpt_util.py
================================================
import os, hashlib
import requests
from tqdm import tqdm

# Download URLs of the converted diffusion checkpoints. Keys with an
# "ema_" prefix point to the EMA variants of the same checkpoints.
URL_MAP = {
    "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1",
    "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1",
    "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1",
    "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1",
    "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1",
    "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1",
    "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1",
    "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1",
}
# Local paths (relative to the cache root) where each checkpoint is stored.
# NOTE(review): "celeba"/"ema_celeba" exist here but have no URL_MAP/MD5_MAP
# entry, so get_ckpt_path() can only resolve them if the file already exists
# locally — confirm that this is intended.
CKPT_MAP = {
    "cifar10": "diffusion_cifar10_model/model-790000.ckpt",
    "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt",
    "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt",
    "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt",
    "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt",
    "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt",
    "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt",
    "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt",
    "celeba": "ema_diffusion_celeba_model/model.ckpt",
    "ema_celeba": "ema_diffusion_celeba_model/model.ckpt",
}

# Expected MD5 digests, used by get_ckpt_path() to verify downloads.
MD5_MAP = {
    "cifar10": "82ed3067fd1002f5cf4c339fb80c4669",
    "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3",
    "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c",
    "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f",
    "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b",
    "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558",
    "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3",
    "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f",
}


def download(url, local_path, chunk_size=1024):
    """Stream *url* to *local_path* with a byte-level progress bar.

    Creates the parent directory if needed. The total is taken from the
    Content-Length header (0 if absent, which makes tqdm open-ended).
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Fix: count the actual bytes received. The final
                        # chunk is usually shorter than chunk_size, so
                        # pbar.update(chunk_size) overshot the total.
                        pbar.update(len(data))


def md5_hash(path):
    """Return the hex MD5 digest of the file at *path*."""
    hasher = hashlib.md5()
    with open(path, "rb") as f:
        hasher.update(f.read())
    return hasher.hexdigest()


def get_ckpt_path(name, root=None, check=False):
    """Return the local path of checkpoint *name*, downloading it if missing.

    With check=True an existing file is also re-downloaded when its MD5
    does not match MD5_MAP. Cache root defaults to $XDG_CACHE_HOME or
    ./run/cache.
    """
    # 'church_outdoor' checkpoints are stored under the 'church' key.
    if 'church_outdoor' in name:
        name = name.replace('church_outdoor', 'church')
    # Modify the path when necessary
    cache_home = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("./run/cache"))
    if root is None:
        root = os.path.join(cache_home, "diffusion_models_converted")
    ckpt_path = os.path.join(root, CKPT_MAP[name])
    needs_fetch = (not os.path.exists(ckpt_path)
                   or (check and md5_hash(ckpt_path) != MD5_MAP[name]))
    if needs_fetch:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], ckpt_path))
        download(URL_MAP[name], ckpt_path)
        digest = md5_hash(ckpt_path)
        assert digest == MD5_MAP[name], digest
    return ckpt_path


================================================
FILE: ddpm_exp/functions/denoising.py
================================================
import torch


def compute_alpha(beta, t):
    """Cumulative product of (1 - beta) up to timestep t, shaped (B, 1, 1, 1).

    A zero is prepended to the schedule so that t = -1 maps to alpha = 1.
    """
    padded = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    alpha_bar = (1 - padded).cumprod(dim=0)
    return alpha_bar.index_select(0, t + 1).view(-1, 1, 1, 1)


def generalized_steps(x, seq, model, b, **kwargs):
    """DDIM-style generalized sampling loop.

    Args:
        x: initial noise tensor (batch first).
        seq: increasing sequence of timesteps to traverse (in reverse).
        model: epsilon-prediction network called as model(x_t, t).
        b: beta schedule tensor.
        eta (kwarg): stochasticity coefficient; 0 gives deterministic DDIM.

    Returns:
        (xs, x0_preds): trajectory of samples and per-step x0 estimates.
        Intermediate tensors are kept on CPU to bound accelerator memory;
        each step's input is moved back to x.device for the model call.
    """
    with torch.no_grad():
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            # Fix: use x.device instead of the hard-coded 'cuda' so the
            # sampler also runs on CPU-only machines. Identical on GPU.
            xt = xs[-1].to(x.device)
            et = model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to('cpu'))
            c1 = (
                kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
            )
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
            xs.append(xt_next.to('cpu'))

    return xs, x0_preds


def ddpm_steps(x, seq, model, b, **kwargs):
    """Ancestral DDPM sampling loop.

    Args:
        x: initial noise tensor (batch first).
        seq: increasing sequence of timesteps to traverse (in reverse).
        model: epsilon-prediction network called as model(x_t, t).
        b: beta schedule tensor.

    Returns:
        (xs, x0_preds): trajectory of samples and per-step (clamped) x0
        estimates. Intermediate tensors live on CPU to bound accelerator
        memory; each step's input is moved back to the input device.
    """
    with torch.no_grad():
        # Fix: honor the input tensor's device instead of hard-coding 'cuda'
        # so the sampler also runs on CPU-only machines. Identical on GPU.
        device = x.device
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        xs = [x]
        x0_preds = []
        betas = b
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n) * i).to(device)
            next_t = (torch.ones(n) * j).to(device)
            at = compute_alpha(betas, t.long())
            atm1 = compute_alpha(betas, next_t.long())
            beta_t = 1 - at / atm1
            x = xs[-1].to(device)

            output = model(x, t.float())
            e = output

            x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
            x0_from_e = torch.clamp(x0_from_e, -1, 1)
            x0_preds.append(x0_from_e.to('cpu'))
            # Posterior mean q(x_{t-1} | x_t, x0).
            mean_eps = (
                (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x
            ) / (1.0 - at)

            mean = mean_eps
            noise = torch.randn_like(x)
            # No noise is added at t == 0 (the final denoising step).
            mask = 1 - (t == 0).float()
            mask = mask.view(-1, 1, 1, 1)
            logvar = beta_t.log()
            sample = mean + mask * torch.exp(0.5 * logvar) * noise
            xs.append(sample.to('cpu'))
    return xs, x0_preds


================================================
FILE: ddpm_exp/functions/losses.py
================================================
import torch


def noise_estimation_loss(model,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, keepdim=False):
    """Epsilon-prediction (simple DDPM) training objective.

    Diffuses x0 to x_t with noise e using the beta schedule b, then measures
    the squared error between e and the model's prediction. Returns the
    per-sample sum over (C, H, W) when keepdim=True, otherwise the batch mean.
    """
    alpha_bar = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x_t = x0 * alpha_bar.sqrt() + e * (1.0 - alpha_bar).sqrt()
    pred = model(x_t, t.float())
    per_sample = (e - pred).square().sum(dim=(1, 2, 3))
    return per_sample if keepdim else per_sample.mean(dim=0)

def noise_estimation_kd_loss(model,
                             teacher,
                          x0: torch.Tensor,
                          t: torch.LongTensor,
                          e: torch.Tensor,
                          b: torch.Tensor, keepdim=False):
    """Knowledge-distillation variant of the epsilon-prediction loss.

    Combines a distillation term (student vs. frozen teacher output, weight
    0.7) with the standard denoising term (student vs. true noise e, weight
    0.3). Returns per-sample sums when keepdim=True, otherwise batch means.
    """
    alpha_bar = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x_t = x0 * alpha_bar.sqrt() + e * (1.0 - alpha_bar).sqrt()
    output = model(x_t, t.float())
    with torch.no_grad():
        teacher_output = teacher(x_t, t.float())
    distill = (teacher_output - output).square().sum(dim=(1, 2, 3))
    denoise = (e - output).square().sum(dim=(1, 2, 3))
    if keepdim:
        return 0.7 * distill + 0.3 * denoise
    return 0.7 * distill.mean(dim=0) + 0.3 * denoise.mean(dim=0)


# Registry mapping config.model.type to its training loss function
# (looked up by the training/pruning code as loss_registry[config.model.type]).
loss_registry = {
    'simple': noise_estimation_loss,
}


================================================
FILE: ddpm_exp/inception.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
# Fetched lazily by fid_inception_v3() via load_state_dict_from_url.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        # Sorted so forward() can stop as soon as the last requested block
        # has been computed.
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights='DEFAULT')

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        # Later blocks are only built when actually needed, saving memory.
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        # Freeze (or unfreeze) all parameters in one pass.
        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # No need to run blocks past the last requested one.
            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`"""
    try:
        version = tuple(int(part) for part in torchvision.__version__.split('.')[:2])
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    # Skips default weight inititialization if supported by torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs['init_weights'] = False

    # Backwards compatibility: `weights` argument was handled by `pretrained`
    # argument prior to version 0.13.
    if version < (0, 13) and 'weights' in kwargs:
        requested = kwargs.pop('weights')
        if requested == 'DEFAULT':
            kwargs['pretrained'] = True
        elif requested is None:
            kwargs['pretrained'] = False
        else:
            raise ValueError(
                'weights=={} not supported in torchvision {}'.format(
                    requested, torchvision.__version__
                )
            )

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    net = _inception_v3(num_classes=1008,
                        aux_logits=False,
                        weights=None)
    # Replace the mixed blocks with the FID-patched variants before loading
    # the ported Tensorflow weights.
    patches = {
        'Mixed_5b': FIDInceptionA(192, pool_features=32),
        'Mixed_5c': FIDInceptionA(256, pool_features=64),
        'Mixed_5d': FIDInceptionA(288, pool_features=64),
        'Mixed_6b': FIDInceptionC(768, channels_7x7=128),
        'Mixed_6c': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6d': FIDInceptionC(768, channels_7x7=160),
        'Mixed_6e': FIDInceptionC(768, channels_7x7=192),
        'Mixed_7b': FIDInceptionE_1(1280),
        'Mixed_7c': FIDInceptionE_2(2048),
    }
    for attr, module in patches.items():
        setattr(net, attr, module)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    net.load_state_dict(state_dict)
    return net


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        # 1x1 branch
        branch1x1 = self.branch1x1(x)

        # 5x5 branch: 1x1 reduction followed by the 5x5 convolution
        branch5x5 = self.branch5x5_2(self.branch5x5_1(x))

        # double-3x3 branch
        branch3x3dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x)))

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                              count_include_pad=False)
        branch_pool = self.branch_pool(pooled)

        return torch.cat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation.

    Identical to torchvision's InceptionC except that the pooling branch
    uses an average pool that ignores the zero padding, matching the FID
    reference implementation. The redundant pass-through ``__init__`` of
    the original was removed; the inherited constructor has the same
    signature.
    """

    def forward(self, x):
        """Run the four parallel branches and concatenate on channels."""
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation.

    Identical to torchvision's InceptionE except that the pooling branch
    uses an average pool that ignores the zero padding. The redundant
    pass-through ``__init__`` of the original was removed; the inherited
    constructor has the same signature.
    """

    def forward(self, x):
        """Run the four parallel branches and concatenate on channels."""
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation.

    Identical to torchvision's InceptionE except that the pooling branch
    uses a *max* pool (see the note in forward). The redundant
    pass-through ``__init__`` of the original was removed; the inherited
    constructor has the same signature.
    """

    def forward(self, x):
        """Run the four parallel branches and concatenate on channels."""
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


================================================
FILE: ddpm_exp/main.py
================================================
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb

from runners.diffusion import Diffusion

torch.set_printoptions(sci_mode=False)


def parse_args_and_config():
    """Parse the command line and the YAML config file.

    Also prepares the log / tensorboard folders (train mode) or the
    sample output folder (sample mode), configures the root logger,
    selects the torch device, and seeds all RNGs.

    Returns:
        (args, new_config): the argparse namespace and the config
        converted to a namespace via dict2namespace (with ``device``
        and, in train mode, ``tb_logger`` attached).
    """
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=1234, help="Random seed")
    parser.add_argument(
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(
        "--comment", type=str, default="", help="A string for experiment comment"
    )
    parser.add_argument(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument(
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")
    parser.add_argument("--interpolation", action="store_true")
    parser.add_argument(
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument("--use_pretrained", action="store_true")
    parser.add_argument(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )
    parser.add_argument(
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument("--sequence", action="store_true")

    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)

    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    tb_path = os.path.join(args.exp, "tensorboard", args.doc)

    if not args.test and not args.sample:
        # Training mode: (re)create the log folder, asking before
        # overwriting unless --ni was given.
        if not args.resume_training:
            if os.path.exists(args.log_path):
                overwrite = False
                if args.ni:
                    overwrite = True
                else:
                    response = input("Folder already exists. Overwrite? (Y/N)")
                    if response.upper() == "Y":
                        overwrite = True

                if overwrite:
                    shutil.rmtree(args.log_path)
                    # Bug fix: the tensorboard folder may not exist yet
                    # (e.g. an earlier run crashed before creating it).
                    # The original removed it unconditionally, raising
                    # FileNotFoundError, and then repeated a dead,
                    # existence-guarded rmtree after os.makedirs.
                    if os.path.exists(tb_path):
                        shutil.rmtree(tb_path)
                    os.makedirs(args.log_path)
                else:
                    print("Folder exists. Program halted.")
                    sys.exit(0)
            else:
                os.makedirs(args.log_path)

            # Keep a copy of the resolved config next to the logs.
            with open(os.path.join(args.log_path, "config.yml"), "w") as f:
                yaml.dump(new_config, f, default_flow_style=False)

        new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
        # setup logger
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt"))
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        handler2.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.addHandler(handler2)
        logger.setLevel(level)

    else:
        # Test/sample mode: console-only logging.
        level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(level, int):
            raise ValueError("level {} not supported".format(args.verbose))

        handler1 = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.setLevel(level)

        if args.sample:
            os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True)
            args.image_folder = os.path.join(
                args.exp, "image_samples", args.image_folder
            )
            if not os.path.exists(args.image_folder):
                os.makedirs(args.image_folder)
            else:
                # FID/interpolation runs append into an existing folder;
                # otherwise ask before overwriting unless --ni was given.
                if not (args.fid or args.interpolation):
                    overwrite = False
                    if args.ni:
                        overwrite = True
                    else:
                        response = input(
                            f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
                        )
                        if response.upper() == "Y":
                            overwrite = True

                    if overwrite:
                        shutil.rmtree(args.image_folder)
                        os.makedirs(args.image_folder)
                    else:
                        print("Output image folder exists. Program halted.")
                        sys.exit(0)

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an argparse.Namespace."""
    ns = argparse.Namespace()
    for key, value in config.items():
        converted = dict2namespace(value) if isinstance(value, dict) else value
        setattr(ns, key, converted)
    return ns


def main():
    """Entry point: build the Diffusion runner and dispatch by mode.

    Returns the process exit status: 0 on success, 1 when the runner
    raised an exception.
    """
    args, config = parse_args_and_config()
    logging.info("Writing log file to {}".format(args.log_path))
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))

    try:
        runner = Diffusion(args, config)
        if args.sample:
            runner.sample()
        elif args.test:
            runner.test()
        else:
            runner.train()
    except Exception:
        logging.error(traceback.format_exc())
        # Bug fix: the original fell through to `return 0` here, so a
        # crashed run still reported success to the shell / job scheduler.
        return 1

    return 0


# Script entry point; the process exit status is main()'s return value.
if __name__ == "__main__":
    sys.exit(main())


================================================
FILE: ddpm_exp/models/diffusion.py
================================================
import math
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings (DDPM / tensor2tensor style).

    Matches the Fairseq / tensor2tensor implementation, which differs
    slightly from the description in Section 3.5 of
    "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps, shape (batch,).
        embedding_dim: width of the returned embedding.

    Returns:
        Tensor of shape (batch, embedding_dim).
    """
    assert len(timesteps.shape) == 1

    half = embedding_dim // 2
    # Geometric sequence of frequencies from 1 down to ~1/10000.
    freqs = torch.exp(
        torch.arange(half, dtype=torch.float32, device=timesteps.device)
        * (-math.log(10000) / (half - 1))
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd width: right-pad one zero column.
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding


def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    gate = torch.sigmoid(x)
    return gate * x


def Normalize(in_channels):
    """GroupNorm with 32 groups, as used throughout this UNet."""
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )


class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # Channel-preserving 3x3 conv applied after the resize.
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        out = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        return self.conv(out) if self.with_conv else out


class Downsample(nn.Module):
    """2x spatial downsampling via a stride-2 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # torch convs have no asymmetric padding option, so the
            # right/bottom padding is applied manually in forward().
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one column on the right and one row at the bottom only.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class ResnetBlock(nn.Module):
    """Pre-activation residual block conditioned on a timestep embedding.

    Two 3x3 convs, each preceded by GroupNorm + swish; the projected
    timestep embedding is added after the first conv. When the channel
    count changes, the skip path is matched with a 3x3 conv
    (``conv_shortcut=True``) or a 1x1 conv otherwise.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE: submodules are created in the same order as the original
        # implementation so that parameter iteration order is unchanged.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        out = self.conv1(nonlinearity(self.norm1(x)))
        # Inject the timestep embedding, broadcast over spatial dims.
        out = out + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        out = self.conv2(self.dropout(nonlinearity(self.norm2(out))))

        skip = x
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                skip = self.conv_shortcut(skip)
            else:
                skip = self.nin_shortcut(skip)

        return skip + out


class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions.

    q, k, v and the output projection are all 1x1 convolutions
    (i.e. per-pixel linear maps); the attention output is added
    residually to the input.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        """Apply self-attention to x of shape (b, c, h, w); returns same shape."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(c) before softmax
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # Residual connection.
        return x+h_


class Model(nn.Module):
    """DDPM UNet: conv-in -> down path -> middle -> up path -> conv-out.

    Architecture hyper-parameters (base width ``ch``, per-level channel
    multipliers ``ch_mult``, number of residual blocks, resolutions at
    which attention is inserted, dropout, ...) are read from the config
    namespace. Skip connections link each down-path activation to the
    corresponding up-path block.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        in_channels = config.model.in_channels
        resolution = config.data.image_size
        resamp_with_conv = config.model.resamp_with_conv
        num_timesteps = config.diffusion.num_diffusion_timesteps

        # Learnable per-timestep log-variance, only for the 'bayesian' type.
        if config.model.type == 'bayesian':
            self.logvar = nn.Parameter(torch.zeros(num_timesteps))

        self.ch = ch
        self.temb_ch = self.ch*4  # timestep-embedding width is 4x base width
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding: two linear layers applied in forward() with a
        # swish in between
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.ch,
                            self.temb_ch),
            torch.nn.Linear(self.temb_ch,
                            self.temb_ch),
        ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+ch_mult  # input multiplier of level i is ch_mult[i-1]
        self.down = nn.ModuleList()
        block_in = None
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # Attention only at the configured spatial resolutions.
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the spatial resolution.
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: ResNet -> attention -> ResNet at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (mirror of the down path, one extra block per level;
        # each block also consumes a popped skip connection)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t):
        """Run the UNet on images x (B, C, H, W) at 1-D timestep tensor t (B,)."""
        assert x.shape[2] == x.shape[3] == self.resolution

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling; hs is the stack of skip-connection activations
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling; each block consumes one skip activation off the stack
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


================================================
FILE: ddpm_exp/models/ema.py
================================================
import torch.nn as nn


class EMAHelper(object):
    """Tracks an exponential moving average (EMA) of a module's parameters.

    Only parameters with ``requires_grad=True`` are tracked. Shadow
    values are kept in registration order, so update()/ema() must see
    the module's trainable parameters in the same order as register().
    """

    def __init__(self, mu=0.999):
        # mu: decay rate; shadow = mu * shadow + (1 - mu) * param
        self.mu = mu
        self.shadow = []
        # Bug fix: initialise so restore() raises the intended
        # RuntimeError (not AttributeError) if store() was never called.
        self.temp_stored_params = None

    def to(self, device=None) -> None:
        """Move all shadow tensors to `device`."""
        self.shadow = [
            p.to(device=device)
            for p in self.shadow
        ]

    def copy_to(self, parameters) -> None:
        """Copy the shadow values into the given parameters, in order."""
        parameters = list(parameters)
        for s_param, param in zip(self.shadow, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def store(self, parameters) -> None:
        r"""
        Args:
        Save the current parameters for restoring later.
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters) -> None:
        """Copy the values saved by the last store() call back into `parameters`."""
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)
        self.temp_stored_params = None

    def register(self, module):
        """Initialise the shadow values from the module's trainable parameters."""
        for param in module.parameters():
            if param.requires_grad:
                self.shadow.append(param.data.clone())

    def update(self, module):
        """Blend the module's current trainable parameters into the shadow."""
        # Bug fix: iterate only over trainable parameters so the zip stays
        # aligned with the shadow list built by register(). The original
        # zipped against *all* parameters and silently misaligned whenever
        # any parameter had requires_grad=False.
        trainable = (p for p in module.parameters() if p.requires_grad)
        for shadow_param, param in zip(self.shadow, trainable):
            shadow_param.data = (
                1. - self.mu) * param.data + self.mu * shadow_param.data

    def ema(self, module):
        """Overwrite the module's trainable parameters with the shadow values."""
        trainable = (p for p in module.parameters() if p.requires_grad)
        for shadow_param, param in zip(self.shadow, trainable):
            param.data.copy_(shadow_param.data)

    def state_dict(self):
        """Return the shadow list (serialisable via torch.save)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Accept either the list returned by state_dict() or a mapping."""
        if isinstance(state_dict, list):
            self.shadow = state_dict
        else:
            # Bug fix: materialise as a list; the original kept a
            # dict-values view, which cannot be appended to or indexed.
            self.shadow = list(state_dict.values())

================================================
FILE: ddpm_exp/prune.py
================================================
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb
from tqdm import tqdm
from runners.diffusion import Diffusion
from torchvision import transforms
import torchvision
from datasets import get_dataset, data_transform, inverse_data_transform
import torchvision.utils as tvu
from utils import UnlabeledImageFolder

torch.set_printoptions(sci_mode=False)


def parse_args_and_config():
    parser = argparse.ArgumentParser(description=globals()["__doc__"])

    parser.add_argument(
        "--config", type=str, required=True, help="Path to the config file"
    )
    parser.add_argument("--seed", type=int, default=2333, help="Random seed")
    parser.add_argument("--taylor_batch_size", type=int, default=128, help="batch size for taylor expansion")
    parser.add_argument(
        "--exp", type=str, default="exp", help="Path for saving running related data."
    )
    parser.add_argument(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. "
        "Will be the name of the log folder.",
    )
    parser.add_argument(
        "--comment", type=str, default="", help="A string for experiment comment"
    )

    parser.add_argument(
        "--load_pruned_model", type=str, default=None, help="load pruned models"
    )

    parser.add_argument(
        "--save_pruned_model", type=str, default=None, help="load pruned models"
    )

    parser.add_argument(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument("--test", action="store_true", help="Whether to test the model")
    parser.add_argument(
        "--sample",
        action="store_true",
        help="Whether to produce samples from the model",
    )
    parser.add_argument("--fid", action="store_true")
    parser.add_argument("--interpolation", action="store_true")
    parser.add_argument(
        "--resume_training", action="store_true", help="Whether to resume training"
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument(
        "--use_generated_samples",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument("--use_pretrained", action="store_true")
    parser.add_argument(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(
        "--skip_type",
        type=str,
        default="uniform",
        help="skip according to (uniform or quadratic)",
    )

    parser.add_argument(
        "--pruner",
        type=str,
        default="taylor",
        choices=["taylor", "random", "magnitude", "reinit", "first_order_taylor", "second_order_taylor", 'abs_taylor', 'fisher', 'ours'],
    )

    parser.add_argument(
        "--restore_from",
        type=str,
        default=None,
        help="Restore from user a checkpoint",
    )
    parser.add_argument(
        "--timesteps", type=int, default=1000, help="number of steps involved"
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument(
        "--thr",
        type=float,
        default=0.01,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument(
        "--pruning_ratio",
        type=float,
        default=0.0,
        help="pruning ratio",
    )
    
    parser.add_argument("--sequence", action="store_true")

    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)

    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cu
Download .txt
gitextract_bsd98edx/

├── .gitignore
├── LICENSE
├── README.md
├── ddpm_exp/
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── calc_fid.py
│   ├── compute_flops.py
│   ├── compute_pruned_ssim_curve.py
│   ├── compute_ssim.py
│   ├── compute_ssim_vis.py
│   ├── configs/
│   │   ├── bedroom.yml
│   │   ├── celeba.yml
│   │   ├── church.yml
│   │   ├── cifar10.yml
│   │   └── cifar10_pruning.yml
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── celeba.py
│   │   ├── ffhq.py
│   │   ├── lsun.py
│   │   ├── utils.py
│   │   └── vision.py
│   ├── draw_ssim_pruned_curve.py
│   ├── extract_cifar10.py
│   ├── fid_score.py
│   ├── finetune.py
│   ├── finetune_simple.py
│   ├── functions/
│   │   ├── __init__.py
│   │   ├── ckpt_util.py
│   │   ├── denoising.py
│   │   └── losses.py
│   ├── inception.py
│   ├── main.py
│   ├── models/
│   │   ├── diffusion.py
│   │   └── ema.py
│   ├── prune.py
│   ├── prune_kd.py
│   ├── prune_ssim.py
│   ├── prune_test.py
│   ├── runners/
│   │   ├── __init__.py
│   │   ├── diffusion.py
│   │   └── diffusion_simple.py
│   ├── scripts/
│   │   ├── finetune_bedroom_ddpm.sh
│   │   ├── finetune_celeba_ddpm.sh
│   │   ├── finetune_celeba_ddpm_kd.sh
│   │   ├── finetune_church_ddpm.sh
│   │   ├── finetune_cifar_ddpm.sh
│   │   ├── finetune_cifar_ddpm_kd.sh
│   │   ├── finetune_cifar_ddpm_random.sh
│   │   ├── finetune_cifar_ddpm_taylor.sh
│   │   ├── old/
│   │   │   ├── run_bedroom_sample_pratrained.sh
│   │   │   ├── run_celeba_pruning_scratch.sh
│   │   │   ├── run_celeba_pruning_taylor.sh
│   │   │   ├── run_celeba_sample_pratrained.sh
│   │   │   ├── run_church_pruning_taylor.sh
│   │   │   ├── run_cifar_pruning_first_order_taylor.sh
│   │   │   ├── run_cifar_pruning_magnitude.sh
│   │   │   ├── run_cifar_pruning_random.sh
│   │   │   ├── run_cifar_pruning_random_kd.sh
│   │   │   ├── run_cifar_pruning_scratch.sh
│   │   │   ├── run_cifar_pruning_second_order_taylor.sh
│   │   │   ├── run_cifar_pruning_taylor.sh
│   │   │   ├── run_cifar_pruning_taylor_kd.sh
│   │   │   └── run_cifar_train.sh
│   │   ├── prune_bedroom_ddpm.sh
│   │   ├── prune_bedroom_ddpm_test.sh
│   │   ├── prune_celeba_ddpm.sh
│   │   ├── prune_celeba_ddpm_ssim.sh
│   │   ├── prune_church_ddpm.sh
│   │   ├── prune_church_ddpm_test.sh
│   │   ├── prune_cifar_ddpm.sh
│   │   ├── prune_cifar_ddpm_ssim.sh
│   │   ├── prune_cifar_ddpm_test.sh
│   │   ├── run_celeba.sh
│   │   ├── sample_bedroom_ddpm_pretrained.sh
│   │   ├── sample_bedroom_ddpm_pruning.sh
│   │   ├── sample_celeba_ddpm_pruning.sh
│   │   ├── sample_celeba_pretrained.sh
│   │   ├── sample_church_ddpm_pruning.sh
│   │   ├── sample_church_ddpm_pruning_old.sh
│   │   ├── sample_church_ddpm_test.sh
│   │   ├── sample_church_pretrained.sh
│   │   ├── sample_cifar_ddpm_pruning.sh
│   │   ├── sample_cifar_pretrained.sh
│   │   ├── simple_celeba_our.sh
│   │   └── simple_cifar_our.sh
│   ├── tools/
│   │   ├── extract_cifar10.py
│   │   └── transform_weights.py
│   ├── torch_pruning/
│   │   ├── __init__.py
│   │   ├── _helpers.py
│   │   ├── dependency.py
│   │   ├── importance.py
│   │   ├── ops.py
│   │   ├── pruner/
│   │   │   ├── __init__.py
│   │   │   ├── algorithms/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── batchnorm_scale_pruner.py
│   │   │   │   ├── group_norm_pruner.py
│   │   │   │   ├── magnitude_based_pruner.py
│   │   │   │   ├── metapruner.py
│   │   │   │   ├── scaling_factor_pruner.py
│   │   │   │   ├── scheduler.py
│   │   │   │   └── taylor_pruner.py
│   │   │   └── function.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── op_counter.py
│   │       └── utils.py
│   └── utils.py
├── ddpm_prune.py
├── ddpm_sample.py
├── ddpm_train.py
├── diffusers/
│   ├── __init__.py
│   ├── commands/
│   │   ├── __init__.py
│   │   ├── diffusers_cli.py
│   │   └── env.py
│   ├── configuration_utils.py
│   ├── dependency_versions_check.py
│   ├── dependency_versions_table.py
│   ├── experimental/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   └── rl/
│   │       ├── __init__.py
│   │       └── value_guided_sampling.py
│   ├── image_processor.py
│   ├── loaders.py
│   ├── models/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── attention_flax.py
│   │   ├── attention_processor.py
│   │   ├── autoencoder_kl.py
│   │   ├── controlnet.py
│   │   ├── controlnet_flax.py
│   │   ├── cross_attention.py
│   │   ├── dual_transformer_2d.py
│   │   ├── embeddings.py
│   │   ├── embeddings_flax.py
│   │   ├── modeling_flax_pytorch_utils.py
│   │   ├── modeling_flax_utils.py
│   │   ├── modeling_pytorch_flax_utils.py
│   │   ├── modeling_utils.py
│   │   ├── prior_transformer.py
│   │   ├── resnet.py
│   │   ├── resnet_flax.py
│   │   ├── t5_film_transformer.py
│   │   ├── transformer_2d.py
│   │   ├── transformer_temporal.py
│   │   ├── unet_1d.py
│   │   ├── unet_1d_blocks.py
│   │   ├── unet_2d.py
│   │   ├── unet_2d_blocks.py
│   │   ├── unet_2d_blocks_flax.py
│   │   ├── unet_2d_condition.py
│   │   ├── unet_2d_condition_flax.py
│   │   ├── unet_3d_blocks.py
│   │   ├── unet_3d_condition.py
│   │   ├── vae.py
│   │   ├── vae_flax.py
│   │   └── vq_model.py
│   ├── optimization.py
│   ├── pipeline_utils.py
│   ├── pipelines/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── alt_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── modeling_roberta_series.py
│   │   │   ├── pipeline_alt_diffusion.py
│   │   │   └── pipeline_alt_diffusion_img2img.py
│   │   ├── audio_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── mel.py
│   │   │   └── pipeline_audio_diffusion.py
│   │   ├── audioldm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_audioldm.py
│   │   ├── controlnet/
│   │   │   ├── __init__.py
│   │   │   ├── multicontrolnet.py
│   │   │   ├── pipeline_controlnet.py
│   │   │   ├── pipeline_controlnet_img2img.py
│   │   │   ├── pipeline_controlnet_inpaint.py
│   │   │   └── pipeline_flax_controlnet.py
│   │   ├── dance_diffusion/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_dance_diffusion.py
│   │   ├── ddim/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_ddim.py
│   │   ├── ddpm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_ddpm.py
│   │   ├── deepfloyd_if/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_if.py
│   │   │   ├── pipeline_if_img2img.py
│   │   │   ├── pipeline_if_img2img_superresolution.py
│   │   │   ├── pipeline_if_inpainting.py
│   │   │   ├── pipeline_if_inpainting_superresolution.py
│   │   │   ├── pipeline_if_superresolution.py
│   │   │   ├── safety_checker.py
│   │   │   ├── timesteps.py
│   │   │   └── watermark.py
│   │   ├── dit/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_dit.py
│   │   ├── latent_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_latent_diffusion.py
│   │   │   └── pipeline_latent_diffusion_superresolution.py
│   │   ├── latent_diffusion_uncond/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_latent_diffusion_uncond.py
│   │   ├── onnx_utils.py
│   │   ├── paint_by_example/
│   │   │   ├── __init__.py
│   │   │   ├── image_encoder.py
│   │   │   └── pipeline_paint_by_example.py
│   │   ├── pipeline_flax_utils.py
│   │   ├── pipeline_utils.py
│   │   ├── pndm/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_pndm.py
│   │   ├── repaint/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_repaint.py
│   │   ├── score_sde_ve/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_score_sde_ve.py
│   │   ├── semantic_stable_diffusion/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_semantic_stable_diffusion.py
│   │   ├── spectrogram_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── continous_encoder.py
│   │   │   ├── midi_utils.py
│   │   │   ├── notes_encoder.py
│   │   │   └── pipeline_spectrogram_diffusion.py
│   │   ├── stable_diffusion/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── convert_from_ckpt.py
│   │   │   ├── pipeline_cycle_diffusion.py
│   │   │   ├── pipeline_flax_stable_diffusion.py
│   │   │   ├── pipeline_flax_stable_diffusion_controlnet.py
│   │   │   ├── pipeline_flax_stable_diffusion_img2img.py
│   │   │   ├── pipeline_flax_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_onnx_stable_diffusion.py
│   │   │   ├── pipeline_onnx_stable_diffusion_img2img.py
│   │   │   ├── pipeline_onnx_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py
│   │   │   ├── pipeline_onnx_stable_diffusion_upscale.py
│   │   │   ├── pipeline_stable_diffusion.py
│   │   │   ├── pipeline_stable_diffusion_attend_and_excite.py
│   │   │   ├── pipeline_stable_diffusion_controlnet.py
│   │   │   ├── pipeline_stable_diffusion_depth2img.py
│   │   │   ├── pipeline_stable_diffusion_diffedit.py
│   │   │   ├── pipeline_stable_diffusion_image_variation.py
│   │   │   ├── pipeline_stable_diffusion_img2img.py
│   │   │   ├── pipeline_stable_diffusion_inpaint.py
│   │   │   ├── pipeline_stable_diffusion_inpaint_legacy.py
│   │   │   ├── pipeline_stable_diffusion_instruct_pix2pix.py
│   │   │   ├── pipeline_stable_diffusion_k_diffusion.py
│   │   │   ├── pipeline_stable_diffusion_latent_upscale.py
│   │   │   ├── pipeline_stable_diffusion_model_editing.py
│   │   │   ├── pipeline_stable_diffusion_panorama.py
│   │   │   ├── pipeline_stable_diffusion_pix2pix_zero.py
│   │   │   ├── pipeline_stable_diffusion_sag.py
│   │   │   ├── pipeline_stable_diffusion_upscale.py
│   │   │   ├── pipeline_stable_unclip.py
│   │   │   ├── pipeline_stable_unclip_img2img.py
│   │   │   ├── safety_checker.py
│   │   │   ├── safety_checker_flax.py
│   │   │   └── stable_unclip_image_normalizer.py
│   │   ├── stable_diffusion_safe/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_stable_diffusion_safe.py
│   │   │   └── safety_checker.py
│   │   ├── stochastic_karras_ve/
│   │   │   ├── __init__.py
│   │   │   └── pipeline_stochastic_karras_ve.py
│   │   ├── text_to_video_synthesis/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_text_to_video_synth.py
│   │   │   └── pipeline_text_to_video_zero.py
│   │   ├── unclip/
│   │   │   ├── __init__.py
│   │   │   ├── pipeline_unclip.py
│   │   │   ├── pipeline_unclip_image_variation.py
│   │   │   └── text_proj.py
│   │   ├── versatile_diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── modeling_text_unet.py
│   │   │   ├── pipeline_versatile_diffusion.py
│   │   │   ├── pipeline_versatile_diffusion_dual_guided.py
│   │   │   ├── pipeline_versatile_diffusion_image_variation.py
│   │   │   └── pipeline_versatile_diffusion_text_to_image.py
│   │   └── vq_diffusion/
│   │       ├── __init__.py
│   │       └── pipeline_vq_diffusion.py
│   ├── schedulers/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── scheduling_ddim.py
│   │   ├── scheduling_ddim_flax.py
│   │   ├── scheduling_ddim_inverse.py
│   │   ├── scheduling_ddpm.py
│   │   ├── scheduling_ddpm_flax.py
│   │   ├── scheduling_deis_multistep.py
│   │   ├── scheduling_dpmsolver_multistep.py
│   │   ├── scheduling_dpmsolver_multistep_flax.py
│   │   ├── scheduling_dpmsolver_multistep_inverse.py
│   │   ├── scheduling_dpmsolver_sde.py
│   │   ├── scheduling_dpmsolver_singlestep.py
│   │   ├── scheduling_euler_ancestral_discrete.py
│   │   ├── scheduling_euler_discrete.py
│   │   ├── scheduling_heun_discrete.py
│   │   ├── scheduling_ipndm.py
│   │   ├── scheduling_k_dpm_2_ancestral_discrete.py
│   │   ├── scheduling_k_dpm_2_discrete.py
│   │   ├── scheduling_karras_ve.py
│   │   ├── scheduling_karras_ve_flax.py
│   │   ├── scheduling_lms_discrete.py
│   │   ├── scheduling_lms_discrete_flax.py
│   │   ├── scheduling_pndm.py
│   │   ├── scheduling_pndm_flax.py
│   │   ├── scheduling_repaint.py
│   │   ├── scheduling_sde_ve.py
│   │   ├── scheduling_sde_ve_flax.py
│   │   ├── scheduling_sde_vp.py
│   │   ├── scheduling_unclip.py
│   │   ├── scheduling_unipc_multistep.py
│   │   ├── scheduling_utils.py
│   │   ├── scheduling_utils_flax.py
│   │   └── scheduling_vq_diffusion.py
│   ├── training_utils.py
│   └── utils/
│       ├── __init__.py
│       ├── accelerate_utils.py
│       ├── constants.py
│       ├── deprecation_utils.py
│       ├── doc_utils.py
│       ├── dummy_flax_and_transformers_objects.py
│       ├── dummy_flax_objects.py
│       ├── dummy_note_seq_objects.py
│       ├── dummy_onnx_objects.py
│       ├── dummy_pt_objects.py
│       ├── dummy_torch_and_librosa_objects.py
│       ├── dummy_torch_and_scipy_objects.py
│       ├── dummy_torch_and_torchsde_objects.py
│       ├── dummy_torch_and_transformers_and_k_diffusion_objects.py
│       ├── dummy_torch_and_transformers_and_onnx_objects.py
│       ├── dummy_torch_and_transformers_objects.py
│       ├── dummy_transformers_and_torch_and_note_seq_objects.py
│       ├── dynamic_modules_utils.py
│       ├── hub_utils.py
│       ├── import_utils.py
│       ├── logging.py
│       ├── model_card_template.md
│       ├── outputs.py
│       ├── pil_utils.py
│       ├── testing_utils.py
│       └── torch_utils.py
├── fid_score.py
├── inception.py
├── ldm_exp/
│   ├── LICENSE
│   ├── README.md
│   ├── configs/
│   │   ├── autoencoder/
│   │   │   ├── autoencoder_kl_16x16x16.yaml
│   │   │   ├── autoencoder_kl_32x32x4.yaml
│   │   │   ├── autoencoder_kl_64x64x3.yaml
│   │   │   └── autoencoder_kl_8x8x64.yaml
│   │   ├── latent-diffusion/
│   │   │   ├── celebahq-ldm-vq-4.yaml
│   │   │   ├── cin-ldm-vq-f8.yaml
│   │   │   ├── cin256-v2.yaml
│   │   │   ├── ffhq-ldm-vq-4.yaml
│   │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
│   │   │   ├── lsun_churches-ldm-kl-8.yaml
│   │   │   └── txt2img-1p4B-eval.yaml
│   │   └── retrieval-augmented-diffusion/
│   │       └── 768x768.yaml
│   ├── environment.yaml
│   ├── fid_score.py
│   ├── inception.py
│   ├── ldm/
│   │   ├── lr_scheduler.py
│   │   ├── models/
│   │   │   ├── autoencoder.py
│   │   │   └── diffusion/
│   │   │       ├── __init__.py
│   │   │       ├── classifier.py
│   │   │       ├── ddim.py
│   │   │       ├── ddpm.py
│   │   │       └── plms.py
│   │   ├── modules/
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── diffusionmodules/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── model.py
│   │   │   │   ├── openaimodel.py
│   │   │   │   └── util.py
│   │   │   ├── distributions/
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── encoders/
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── image_degradation/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bsrgan.py
│   │   │   │   ├── bsrgan_light.py
│   │   │   │   └── utils_image.py
│   │   │   ├── losses/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   └── x_transformer.py
│   │   └── util.py
│   ├── main.py
│   ├── models/
│   │   ├── first_stage_models/
│   │   │   ├── kl-f16/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f32/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f4/
│   │   │   │   └── config.yaml
│   │   │   ├── kl-f8/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f16/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f4-noattn/
│   │   │   │   └── config.yaml
│   │   │   ├── vq-f8/
│   │   │   │   └── config.yaml
│   │   │   └── vq-f8-n256/
│   │   │       └── config.yaml
│   │   └── ldm/
│   │       ├── bsr_sr/
│   │       │   └── config.yaml
│   │       ├── celeba256/
│   │       │   └── config.yaml
│   │       ├── cin256/
│   │       │   └── config.yaml
│   │       ├── ffhq256/
│   │       │   └── config.yaml
│   │       ├── inpainting_big/
│   │       │   └── config.yaml
│   │       ├── layout2img-openimages256/
│   │       │   └── config.yaml
│   │       ├── lsun_beds256/
│   │       │   └── config.yaml
│   │       ├── lsun_churches256/
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis256/
│   │       │   └── config.yaml
│   │       ├── semantic_synthesis512/
│   │       │   └── config.yaml
│   │       └── text2img256/
│   │           └── config.yaml
│   ├── notebook_helpers.py
│   ├── profile_ldm.py
│   ├── profile_ldm_pretrained.py
│   ├── profile_model.py
│   ├── prune_ldm.py
│   ├── prune_ldm_no_grad.py
│   ├── run.sh
│   ├── sample_for_FID.py
│   ├── sample_imagenet.py
│   ├── sample_pruned.py
│   ├── scripts/
│   │   ├── download_first_stages.sh
│   │   ├── download_models.sh
│   │   ├── inpaint.py
│   │   ├── knn2img.py
│   │   ├── latent_imagenet_diffusion.ipynb
│   │   ├── sample_diffusion.py
│   │   ├── train_searcher.py
│   │   └── txt2img.py
│   ├── setup.py
│   ├── test_criterion.py
│   └── test_diffusion.py
├── ldm_prune.py
├── requirements.txt
├── scripts/
│   ├── finetune_ddpm_cifar10.sh
│   ├── prune_ddpm_cifar10.sh
│   ├── prune_ddpm_ema_bedroom_random.sh
│   ├── prune_ddpm_ema_church_random.sh
│   ├── prune_ldm.sh
│   ├── sample_ddpm_cifar10_pretrained.sh
│   ├── sample_ddpm_cifar10_pretrained_distributed.sh
│   └── sample_ddpm_cifar10_pruned.sh
├── tools/
│   ├── convert_cifar10_ddpm_ema.sh
│   ├── convert_ddpm_original_checkpoint_to_diffusers_cifar10.py
│   ├── convert_ldm_original_checkpoint_to_diffusers.py
│   ├── ddpm_cifar10_config.json
│   ├── extract_cifar10.py
│   └── ldm_unet_config.json
└── utils.py
Download .txt
Showing preview only (315K chars total). Download the full file or copy to clipboard to get everything.
SYMBOL INDEX (3841 symbols across 270 files)

FILE: ddpm_exp/datasets/__init__.py
  class Crop (line 14) | class Crop(object):
    method __init__ (line 15) | def __init__(self, x1, x2, y1, y2):
    method __call__ (line 21) | def __call__(self, img):
    method __repr__ (line 24) | def __repr__(self):
  function get_dataset (line 30) | def get_dataset(args, config):
  function logit_transform (line 184) | def logit_transform(image, lam=1e-6):
  function data_transform (line 189) | def data_transform(config, X):
  function inverse_data_transform (line 206) | def inverse_data_transform(config, X):

FILE: ddpm_exp/datasets/celeba.py
  class CelebA (line 8) | class CelebA(VisionDataset):
    method __init__ (line 50) | def __init__(self, root,
    method _check_integrity (line 108) | def _check_integrity(self):
    method download (line 120) | def download(self):
    method __getitem__ (line 133) | def __getitem__(self, index):
    method __len__ (line 158) | def __len__(self):
    method extra_repr (line 161) | def extra_repr(self):

FILE: ddpm_exp/datasets/ffhq.py
  class FFHQ (line 8) | class FFHQ(Dataset):
    method __init__ (line 9) | def __init__(self, path, transform, resolution=8):
    method __len__ (line 28) | def __len__(self):
    method __getitem__ (line 31) | def __getitem__(self, index):

FILE: ddpm_exp/datasets/lsun.py
  class LSUNClass (line 11) | class LSUNClass(VisionDataset):
    method __init__ (line 12) | def __init__(self, root, transform=None, target_transform=None):
    method __getitem__ (line 38) | def __getitem__(self, index):
    method __len__ (line 57) | def __len__(self):
  class LSUN (line 61) | class LSUN(VisionDataset):
    method __init__ (line 75) | def __init__(self, root, classes="train", transform=None, target_trans...
    method _verify_classes (line 96) | def _verify_classes(self, classes):
    method __getitem__ (line 146) | def __getitem__(self, index):
    method __len__ (line 171) | def __len__(self):
    method extra_repr (line 174) | def extra_repr(self):

FILE: ddpm_exp/datasets/utils.py
  function gen_bar_updater (line 8) | def gen_bar_updater():
  function check_integrity (line 20) | def check_integrity(fpath, md5=None):
  function makedir_exist_ok (line 36) | def makedir_exist_ok(dirpath):
  function download_url (line 49) | def download_url(url, root, filename=None, md5=None):
  function list_dir (line 88) | def list_dir(root, prefix=False):
  function list_files (line 110) | def list_files(root, suffix, prefix=False):
  function download_file_from_google_drive (line 134) | def download_file_from_google_drive(file_id, root, filename=None, md5=No...
  function _get_confirm_token (line 169) | def _get_confirm_token(response):
  function _save_response_content (line 177) | def _save_response_content(response, destination, chunk_size=32768):

FILE: ddpm_exp/datasets/vision.py
  class VisionDataset (line 6) | class VisionDataset(data.Dataset):
    method __init__ (line 9) | def __init__(self, root, transforms=None, transform=None, target_trans...
    method __getitem__ (line 28) | def __getitem__(self, index):
    method __len__ (line 31) | def __len__(self):
    method __repr__ (line 34) | def __repr__(self):
    method _format_transform_repr (line 49) | def _format_transform_repr(self, transform, head):
    method extra_repr (line 54) | def extra_repr(self):
  class StandardTransform (line 58) | class StandardTransform(object):
    method __init__ (line 59) | def __init__(self, transform=None, target_transform=None):
    method __call__ (line 63) | def __call__(self, input, target):
    method _format_transform_repr (line 70) | def _format_transform_repr(self, transform, head):
    method __repr__ (line 75) | def __repr__(self):

FILE: ddpm_exp/fid_score.py
  function tqdm (line 49) | def tqdm(x):
  class ImagePathDataset (line 84) | class ImagePathDataset(torch.utils.data.Dataset):
    method __init__ (line 85) | def __init__(self, files, transforms=None):
    method __len__ (line 89) | def __len__(self):
    method __getitem__ (line 92) | def __getitem__(self, i):
  function get_activations (line 100) | def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
  function calculate_frechet_distance (line 182) | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  function calculate_activation_statistics (line 239) | def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
  function compute_statistics_of_path (line 264) | def compute_statistics_of_path(path, model, batch_size, dims, device,
  function calculate_fid_given_paths (line 285) | def calculate_fid_given_paths(paths, batch_size, device, dims, num_worke...
  function save_fid_stats (line 304) | def save_fid_stats(paths, batch_size, device, dims, num_workers=1, num_s...
  function main (line 324) | def main():

FILE: ddpm_exp/finetune.py
  function parse_args_and_config (line 21) | def parse_args_and_config(accelerator):
  function dict2namespace (line 245) | def dict2namespace(config):
  function main (line 256) | def main():

FILE: ddpm_exp/finetune_simple.py
  function parse_args_and_config (line 22) | def parse_args_and_config():
  function dict2namespace (line 247) | def dict2namespace(config):
  function main (line 258) | def main():

FILE: ddpm_exp/functions/__init__.py
  function get_optimizer (line 4) | def get_optimizer(config, parameters):

FILE: ddpm_exp/functions/ckpt_util.py
  function download (line 40) | def download(url, local_path, chunk_size=1024):
  function md5_hash (line 52) | def md5_hash(path):
  function get_ckpt_path (line 58) | def get_ckpt_path(name, root=None, check=False):

FILE: ddpm_exp/functions/denoising.py
  function compute_alpha (line 4) | def compute_alpha(beta, t):
  function generalized_steps (line 10) | def generalized_steps(x, seq, model, b, **kwargs):
  function ddpm_steps (line 35) | def ddpm_steps(x, seq, model, b, **kwargs):

FILE: ddpm_exp/functions/losses.py
  function noise_estimation_loss (line 4) | def noise_estimation_loss(model,
  function noise_estimation_kd_loss (line 17) | def noise_estimation_kd_loss(model,

FILE: ddpm_exp/inception.py
  class InceptionV3 (line 16) | class InceptionV3(nn.Module):
    method __init__ (line 31) | def __init__(self,
    method forward (line 129) | def forward(self, inp):
  function _inception_v3 (line 166) | def _inception_v3(*args, **kwargs):
  function fid_inception_v3 (line 197) | def fid_inception_v3():
  class FIDInceptionA (line 224) | class FIDInceptionA(torchvision.models.inception.InceptionA):
    method __init__ (line 226) | def __init__(self, in_channels, pool_features):
    method forward (line 229) | def forward(self, x):
  class FIDInceptionC (line 249) | class FIDInceptionC(torchvision.models.inception.InceptionC):
    method __init__ (line 251) | def __init__(self, in_channels, channels_7x7):
    method forward (line 254) | def forward(self, x):
  class FIDInceptionE_1 (line 277) | class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    method __init__ (line 279) | def __init__(self, in_channels):
    method forward (line 282) | def forward(self, x):
  class FIDInceptionE_2 (line 310) | class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    method __init__ (line 312) | def __init__(self, in_channels):
    method forward (line 315) | def forward(self, x):

FILE: ddpm_exp/main.py
  function parse_args_and_config (line 17) | def parse_args_and_config():
  function dict2namespace (line 200) | def dict2namespace(config):
  function main (line 211) | def main():

FILE: ddpm_exp/models/diffusion.py
  function get_timestep_embedding (line 6) | def get_timestep_embedding(timesteps, embedding_dim):
  function nonlinearity (line 27) | def nonlinearity(x):
  function Normalize (line 32) | def Normalize(in_channels):
  class Upsample (line 36) | class Upsample(nn.Module):
    method __init__ (line 37) | def __init__(self, in_channels, with_conv):
    method forward (line 47) | def forward(self, x):
  class Downsample (line 55) | class Downsample(nn.Module):
    method __init__ (line 56) | def __init__(self, in_channels, with_conv):
    method forward (line 67) | def forward(self, x):
  class ResnetBlock (line 77) | class ResnetBlock(nn.Module):
    method __init__ (line 78) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 115) | def forward(self, x, temb):
  class AttnBlock (line 137) | class AttnBlock(nn.Module):
    method __init__ (line 138) | def __init__(self, in_channels):
    method forward (line 164) | def forward(self, x):
  class Model (line 192) | class Model(nn.Module):
    method __init__ (line 193) | def __init__(self, config):
    method forward (line 301) | def forward(self, x, t):

FILE: ddpm_exp/models/ema.py
  class EMAHelper (line 4) | class EMAHelper(object):
    method __init__ (line 5) | def __init__(self, mu=0.999):
    method to (line 9) | def to(self, device=None) -> None:
    method copy_to (line 15) | def copy_to(self, parameters) -> None:
    method store (line 20) | def store(self, parameters) -> None:
    method restore (line 29) | def restore(self, parameters) -> None:
    method register (line 36) | def register(self, module):
    method update (line 41) | def update(self, module):
    method ema (line 49) | def ema(self, module):
    method state_dict (line 68) | def state_dict(self):
    method load_state_dict (line 71) | def load_state_dict(self, state_dict):

FILE: ddpm_exp/prune.py
  function parse_args_and_config (line 22) | def parse_args_and_config():
  function dict2namespace (line 164) | def dict2namespace(config):
  function main (line 175) | def main():

FILE: ddpm_exp/prune_kd.py
  function parse_args_and_config (line 21) | def parse_args_and_config():
  function dict2namespace (line 240) | def dict2namespace(config):
  function main (line 251) | def main():

FILE: ddpm_exp/prune_ssim.py
  function parse_args_and_config (line 22) | def parse_args_and_config():
  function dict2namespace (line 165) | def dict2namespace(config):
  function main (line 176) | def main():

FILE: ddpm_exp/prune_test.py
  function parse_args_and_config (line 22) | def parse_args_and_config():
  function dict2namespace (line 158) | def dict2namespace(config):
  function main (line 169) | def main():

FILE: ddpm_exp/runners/diffusion.py
  function torch2hwcuint8 (line 21) | def torch2hwcuint8(x, clip=False):
  function get_beta_schedule (line 28) | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffus...
  class Diffusion (line 61) | class Diffusion(object):
    method __init__ (line 62) | def __init__(self, args, config, device=None):
    method build_model (line 100) | def build_model(self):
    method train (line 197) | def train(self, kd=False):
    method sample (line 375) | def sample(self):
    method sample_fid (line 390) | def sample_fid(self, model):
    method sample_sequence (line 429) | def sample_sequence(self, model):
    method sample_interpolation (line 452) | def sample_interpolation(self, model):
    method sample_image (line 492) | def sample_image(self, x, model, last=True):
    method test (line 539) | def test(self):

FILE: ddpm_exp/runners/diffusion_simple.py
  function torch2hwcuint8 (line 19) | def torch2hwcuint8(x, clip=False):
  function get_beta_schedule (line 26) | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffus...
  class Diffusion (line 59) | class Diffusion(object):
    method __init__ (line 60) | def __init__(self, args, config, device=None):
    method build_model (line 98) | def build_model(self):
    method train (line 176) | def train(self):
    method sample (line 303) | def sample(self):
    method sample_fid (line 318) | def sample_fid(self, model):
    method sample_sequence (line 354) | def sample_sequence(self, model):
    method sample_interpolation (line 377) | def sample_interpolation(self, model):
    method sample_image (line 417) | def sample_image(self, x, model, last=True):
    method test (line 464) | def test(self):

FILE: ddpm_exp/torch_pruning/_helpers.py
  function is_scalar (line 8) | def is_scalar(x):
  class _FlattenIndexMapping (line 18) | class _FlattenIndexMapping(object):
    method __init__ (line 19) | def __init__(self, stride=1, reverse=False):
    method __call__ (line 23) | def __call__(self, idxs):
  class _ConcatIndexMapping (line 36) | class _ConcatIndexMapping(object):
    method __init__ (line 37) | def __init__(self, offset, reverse=False):
    method __call__ (line 41) | def __call__(self, idxs):
  class _SplitIndexMapping (line 54) | class _SplitIndexMapping(object):
    method __init__ (line 55) | def __init__(self, offset, reverse=False):
    method __call__ (line 59) | def __call__(self, idxs):
  class _GroupConvIndexMapping (line 71) | class _GroupConvIndexMapping(object):
    method __init__ (line 72) | def __init__(self, in_channels, out_channels, groups, reverse=False):
    method __call__ (line 78) | def __call__(self, idxs):
  class ScalarSum (line 89) | class ScalarSum:
    method __init__ (line 90) | def __init__(self):
    method update (line 93) | def update(self, metric_name, metric_value):
    method results (line 98) | def results(self):
    method reset (line 101) | def reset(self):
  class VectorSum (line 105) | class VectorSum:
    method __init__ (line 106) | def __init__(self):
    method update (line 109) | def update(self, metric_name, metric_value):
    method results (line 119) | def results(self):
    method reset (line 122) | def reset(self):

FILE: ddpm_exp/torch_pruning/dependency.py
  class Node (line 15) | class Node(object):
    method __init__ (line 19) | def __init__(self, module: nn.Module, grad_fn, name: str = None):
    method name (line 35) | def name(self):
    method add_input (line 44) | def add_input(self, node, allow_dumplicated=False):
    method add_output (line 52) | def add_output(self, node, allow_dumplicated=False):
    method __repr__ (line 59) | def __repr__(self):
    method __str__ (line 62) | def __str__(self):
    method details (line 65) | def details(self):
  class Edge (line 83) | class Edge():  # for readability
  class Dependency (line 87) | class Dependency(Edge):
    method __init__ (line 88) | def __init__(
    method __call__ (line 109) | def __call__(self, idxs: list):
    method __repr__ (line 117) | def __repr__(self):
    method __str__ (line 120) | def __str__(self):
    method is_triggered_by (line 128) | def is_triggered_by(self, pruning_fn):
    method __eq__ (line 131) | def __eq__(self, other):
    method __hash__ (line 139) | def __hash__(self):
  class Group (line 146) | class Group(object):
    method __init__ (line 153) | def __init__(self):
    method prune (line 157) | def prune(self, idxs=None, record_history=True):
    method add_dep (line 187) | def add_dep(self, dep, idxs):
    method __getitem__ (line 190) | def __getitem__(self, k):
    method items (line 194) | def items(self):
    method has_dep (line 197) | def has_dep(self, dep):
    method has_pruning_op (line 203) | def has_pruning_op(self, dep, idxs):
    method __len__ (line 213) | def __len__(self):
    method add_and_merge (line 216) | def add_and_merge(self, dep, idxs):
    method __str__ (line 223) | def __str__(self):
    method details (line 233) | def details(self):
    method exec (line 246) | def exec(self):
    method __call__ (line 250) | def __call__(self):
  class DependencyGraph (line 255) | class DependencyGraph(object):
    method __init__ (line 257) | def __init__(self):
    method pruning_history (line 278) | def pruning_history(self):
    method load_pruning_history (line 281) | def load_pruning_history(self, pruning_history):
    method build_dependency (line 295) | def build_dependency(
    method register_customized_layer (line 385) | def register_customized_layer(
    method check_pruning_group (line 400) | def check_pruning_group(self, group: Group) -> bool:
    method is_out_channel_pruning_fn (line 422) | def is_out_channel_pruning_fn(self, fn: typing.Callable) -> bool:
    method is_in_channel_pruning_fn (line 425) | def is_in_channel_pruning_fn(self, fn: typing.Callable) -> bool:
    method get_pruning_plan (line 428) | def get_pruning_plan(self, module: nn.Module, pruning_fn: typing.Calla...
    method get_pruning_group (line 433) | def get_pruning_group(
    method get_all_groups (line 498) | def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TOR...
    method get_pruner_of_module (line 529) | def get_pruner_of_module(self, module):
    method get_out_channels (line 535) | def get_out_channels(self, module_or_node):
    method get_in_channels (line 548) | def get_in_channels(self, module_or_node):
    method _infer_out_channels_recursively (line 561) | def _infer_out_channels_recursively(self, node: Node):
    method _infer_in_channels_recursively (line 584) | def _infer_in_channels_recursively(self, node: Node):
    method _build_dependency (line 602) | def _build_dependency(self, module2node):
    method _trace (line 631) | def _trace(self, model, example_inputs, forward_fn, output_transform):
    method _trace_computational_graph (line 707) | def _trace_computational_graph(self, module2node, grad_fn_root, gradfn...
    method update_index_mapping (line 813) | def update_index_mapping(self):
    method _init_shape_information (line 825) | def _init_shape_information(self):
    method _update_flatten_index_mapping (line 856) | def _update_flatten_index_mapping(self, fc_node: Node):
    method _update_reshape_index_mapping (line 883) | def _update_reshape_index_mapping(self, reshape_node: Node):
    method _update_concat_index_mapping (line 946) | def _update_concat_index_mapping(self, cat_node: Node):
    method _update_split_index_mapping (line 993) | def _update_split_index_mapping(self, split_node: Node):
    method infer_channels (line 1024) | def infer_channels(self, node_1, node_2):

FILE: ddpm_exp/torch_pruning/importance.py
  class Importance (line 11) | class Importance(abc.ABC):
    method __call__ (line 15) | def __call__(self, group)-> torch.Tensor:
  class MagnitudeImportance (line 18) | class MagnitudeImportance(Importance):
    method __init__ (line 19) | def __init__(self, p=2, group_reduction="mean", normalizer='mean'):
    method _normalize (line 24) | def _normalize(self, group_importance, normalizer):
    method _reduce (line 42) | def _reduce(self, group_imp):
    method __call__ (line 60) | def __call__(self, group, ch_groups=1):
  class BNScaleImportance (line 129) | class BNScaleImportance(MagnitudeImportance):
    method __init__ (line 133) | def __init__(self, group_reduction='mean', normalizer='mean'):
    method __call__ (line 136) | def __call__(self, group, ch_groups=1):
  class LAMPImportance (line 154) | class LAMPImportance(MagnitudeImportance):
    method __init__ (line 158) | def __init__(self, p=2, group_reduction="mean", normalizer='mean'):
    method __call__ (line 162) | def __call__(self, group, **kwargs):
    method lamp (line 211) | def lamp(self, imp):
  class RandomImportance (line 221) | class RandomImportance(Importance):
    method __call__ (line 223) | def __call__(self, group, **kwargs):
  class GroupNormImportance (line 227) | class GroupNormImportance(MagnitudeImportance):
    method __init__ (line 228) | def __init__(self, p=2, normalizer='max'):
    method __call__ (line 234) | def __call__(self, group, ch_groups=1):
  class TaylorImportance (line 332) | class TaylorImportance(Importance):
    method __init__ (line 333) | def __init__(self, group_reduction="mean", normalizer='mean'):
    method set_model (line 337) | def set_model(self, model):
    method _normalize (line 340) | def _normalize(self, group_importance, normalizer):
    method _reduce (line 358) | def _reduce(self, group_imp):
    method __call__ (line 376) | def __call__(self, group, ch_groups=1):
  class FullTaylorImportance (line 438) | class FullTaylorImportance(Importance):
    method __init__ (line 439) | def __init__(self, order=1, group_reduction="mean", normalizer='mean'):
    method set_model (line 444) | def set_model(self, model):
    method _normalize (line 447) | def _normalize(self, group_importance, normalizer):
    method _reduce (line 465) | def _reduce(self, group_imp):
    method __call__ (line 483) | def __call__(self, group, ch_groups=1):
  class AbsTaylorImportance (line 553) | class AbsTaylorImportance(Importance):
    method __init__ (line 554) | def __init__(self, order=1, group_reduction="mean", normalizer='mean'):
    method set_model (line 560) | def set_model(self, model):
    method _normalize (line 563) | def _normalize(self, group_importance, normalizer):
    method _reduce (line 581) | def _reduce(self, group_imp):
    method accum_abs_grad (line 598) | def accum_abs_grad(self, model):
    method assign_abs_grad (line 606) | def assign_abs_grad(self, model):
    method __call__ (line 612) | def __call__(self, group, ch_groups=1):
  class FisherImportance (line 672) | class FisherImportance(Importance):
    method __init__ (line 673) | def __init__(self, group_reduction="mean", normalizer='mean'):
    method set_model (line 677) | def set_model(self, model):
    method _normalize (line 680) | def _normalize(self, group_importance, normalizer):
    method _reduce (line 698) | def _reduce(self, group_imp):
    method __call__ (line 716) | def __call__(self, group, ch_groups=1):

FILE: ddpm_exp/torch_pruning/ops.py
  class DummyMHA (line 5) | class DummyMHA(nn.Module):
    method __init__ (line 6) | def __init__(self):
  class _CustomizedOp (line 10) | class _CustomizedOp(nn.Module):
    method __init__ (line 11) | def __init__(self, op_class):
    method __repr__ (line 14) | def __repr__(self):
  class _ConcatOp (line 18) | class _ConcatOp(nn.Module):
    method __init__ (line 19) | def __init__(self, id):
    method __repr__ (line 25) | def __repr__(self):
  class _SplitOp (line 29) | class _SplitOp(nn.Module):
    method __init__ (line 30) | def __init__(self, id):
    method __repr__ (line 36) | def __repr__(self):
  class _ReshapeOp (line 39) | class _ReshapeOp(nn.Module):
    method __init__ (line 40) | def __init__(self, id):
    method __repr__ (line 43) | def __repr__(self):
  class _ElementWiseOp (line 47) | class _ElementWiseOp(nn.Module):
    method __init__ (line 48) | def __init__(self, id, grad_fn):
    method __repr__ (line 52) | def __repr__(self):
  class DummyPruner (line 58) | class DummyPruner(object):
    method __call__ (line 59) | def __call__(self, layer, *args, **kargs):
    method prune_out_channels (line 62) | def prune_out_channels(self, layer, idxs):
    method get_out_channels (line 67) | def get_out_channels(self, layer):
    method get_in_channels (line 70) | def get_in_channels(self, layer):
  class ConcatPruner (line 74) | class ConcatPruner(DummyPruner):
    method prune_out_channels (line 75) | def prune_out_channels(self, layer, idxs):
  class SplitPruner (line 98) | class SplitPruner(DummyPruner):
    method prune_out_channels (line 99) | def prune_out_channels(self, layer, idxs):
  class ReshapePruner (line 125) | class ReshapePruner(DummyPruner):
  class ElementWisePruner (line 128) | class ElementWisePruner(DummyPruner):
  class OPTYPE (line 150) | class OPTYPE(IntEnum):
  function module2type (line 170) | def module2type(module):
  function type2class (line 208) | def type2class(op_type):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py
  class BNScalePruner (line 8) | class BNScalePruner(MetaPruner):
    method __init__ (line 9) | def __init__(
    method regularize (line 45) | def regularize(self, model):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/group_norm_pruner.py
  class GroupNormPruner (line 9) | class GroupNormPruner(MetaPruner):
    method __init__ (line 10) | def __init__(
    method regularize (line 55) | def regularize(self, model, base=16):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/magnitude_based_pruner.py
  class MagnitudePruner (line 3) | class MagnitudePruner(MetaPruner):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/metapruner.py
  class MetaPruner (line 11) | class MetaPruner:
    method __init__ (line 34) | def __init__(
    method pruning_history (line 135) | def pruning_history(self):
    method load_pruning_history (line 138) | def load_pruning_history(self, pruning_history):
    method get_target_sparsity (line 141) | def get_target_sparsity(self, module):
    method reset (line 146) | def reset(self):
    method regularize (line 149) | def regularize(self, model, loss):
    method step (line 154) | def step(self, interactive=False):
    method estimate_importance (line 169) | def estimate_importance(self, group, ch_groups=1):
    method _check_sparsity (line 172) | def _check_sparsity(self, group):
    method get_channel_groups (line 196) | def get_channel_groups(self, group):
    method prune_local (line 205) | def prune_local(self):
    method prune_global (line 256) | def prune_global(self):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/scaling_factor_pruner.py
  class ScalingFactorPruner (line 11) | class ScalingFactorPruner(MetaPruner):
    method __init__ (line 12) | def __init__(
    method regularize (line 51) | def regularize(self, model):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/scheduler.py
  function linear_scheduler (line 2) | def linear_scheduler(ch_sparsity_dict, steps):

FILE: ddpm_exp/torch_pruning/pruner/algorithms/taylor_pruner.py
  class TaylorPruner (line 9) | class TaylorPruner(MetaPruner):
    method __init__ (line 10) | def __init__(
    method regularize (line 55) | def regularize(self, model, base=16):

FILE: ddpm_exp/torch_pruning/pruner/function.py
  class BasePruningFunc (line 41) | class BasePruningFunc(ABC):
    method __init__ (line 44) | def __init__(self, pruning_dim=1):
    method prune_out_channels (line 48) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]):
    method prune_in_channels (line 52) | def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]):
    method get_out_channels (line 56) | def get_out_channels(self, layer: nn.Module):
    method get_in_channels (line 60) | def get_in_channels(self, layer: nn.Module):
    method check (line 63) | def check(self, layer, idxs, to_output):
    method __call__ (line 75) | def __call__(self, layer: nn.Module, idxs: Sequence[int], to_output: b...
  class ConvPruner (line 85) | class ConvPruner(BasePruningFunc):
    method prune_out_channels (line 88) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
    method prune_in_channels (line 117) | def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]) -> ...
    method get_out_channels (line 142) | def get_out_channels(self, layer):
    method get_in_channels (line 145) | def get_in_channels(self, layer):
  class DepthwiseConvPruner (line 149) | class DepthwiseConvPruner(ConvPruner):
    method prune_out_channels (line 152) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
  class LinearPruner (line 168) | class LinearPruner(BasePruningFunc):
    method prune_out_channels (line 171) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
    method prune_in_channels (line 190) | def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]) -> ...
    method get_out_channels (line 203) | def get_out_channels(self, layer):
    method get_in_channels (line 206) | def get_in_channels(self, layer):
  class BatchnormPruner (line 210) | class BatchnormPruner(BasePruningFunc):
    method prune_out_channels (line 213) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
    method get_out_channels (line 229) | def get_out_channels(self, layer):
    method get_in_channels (line 232) | def get_in_channels(self, layer):
  class LayernormPruner (line 236) | class LayernormPruner(BasePruningFunc):
    method __init__ (line 239) | def __init__(self, metrcis=None, pruning_dim=-1):
    method check (line 243) | def check(self, layer, idxs):
    method prune_out_channels (line 246) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
    method get_out_channels (line 268) | def get_out_channels(self, layer):
    method get_in_channels (line 271) | def get_in_channels(self, layer):
  class GroupNormPruner (line 274) | class GroupNormPruner(BasePruningFunc):
    method prune_out_channels (line 275) | def prune_out_channels(self, layer: nn.GroupNorm, idxs: list) -> nn.Mo...
    method get_out_channels (line 298) | def get_out_channels(self, layer):
    method get_in_channels (line 301) | def get_in_channels(self, layer):
  class InstanceNormPruner (line 304) | class InstanceNormPruner(BasePruningFunc):
    method prune_out_channels (line 305) | def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) ->...
    method get_out_channels (line 317) | def get_out_channels(self, layer):
    method get_in_channels (line 320) | def get_in_channels(self, layer):
  class PReLUPruner (line 324) | class PReLUPruner(BasePruningFunc):
    method prune_out_channels (line 327) | def prune_out_channels(self, layer: nn.PReLU, idxs: list) -> nn.Module:
    method get_out_channels (line 341) | def get_out_channels(self, layer):
    method get_in_channels (line 347) | def get_in_channels(self, layer):
  class EmbeddingPruner (line 350) | class EmbeddingPruner(BasePruningFunc):
    method prune_out_channels (line 353) | def prune_out_channels(self, layer: nn.Embedding, idxs: list) -> nn.Mo...
    method get_out_channels (line 367) | def get_out_channels(self, layer):
    method get_in_channels (line 370) | def get_in_channels(self, layer):
  class LSTMPruner (line 373) | class LSTMPruner(BasePruningFunc):
    method prune_out_channels (line 376) | def prune_out_channels(self, layer: nn.LSTM, idxs: list) -> nn.Module:
    method prune_in_channels (line 405) | def prune_in_channels(self, layer: nn.LSTM, idxs: list):
    method get_out_channels (line 416) | def get_out_channels(self, layer):
    method get_in_channels (line 419) | def get_in_channels(self, layer):
  class ParameterPruner (line 423) | class ParameterPruner(BasePruningFunc):
    method __init__ (line 425) | def __init__(self, pruning_dim=-1):
    method prune_out_channels (line 428) | def prune_out_channels(self, tensor, idxs: list) -> nn.Module:
    method get_out_channels (line 437) | def get_out_channels(self, parameter):
    method get_in_channels (line 440) | def get_in_channels(self, parameter):
  class MultiheadAttentionPruner (line 444) | class MultiheadAttentionPruner(BasePruningFunc):
    method check (line 447) | def check(self, layer, idxs, to_output):
    method prune_out_channels (line 451) | def prune_out_channels(self, layer, idxs: list) -> nn.Module:
    method get_out_channels (line 513) | def get_out_channels(self, layer):
    method get_in_channels (line 516) | def get_in_channels(self, layer):

FILE: ddpm_exp/torch_pruning/utils/op_counter.py
  function count_ops_and_params (line 16) | def count_ops_and_params(model, example_inputs):
  function empty_flops_counter_hook (line 35) | def empty_flops_counter_hook(module, input, output):
  function upsample_flops_counter_hook (line 39) | def upsample_flops_counter_hook(module, input, output):
  function relu_flops_counter_hook (line 48) | def relu_flops_counter_hook(module, input, output):
  function linear_flops_counter_hook (line 53) | def linear_flops_counter_hook(module, input, output):
  function pool_flops_counter_hook (line 61) | def pool_flops_counter_hook(module, input, output):
  function bn_flops_counter_hook (line 66) | def bn_flops_counter_hook(module, input, output):
  function conv_flops_counter_hook (line 75) | def conv_flops_counter_hook(conv_module, input, output):
  function rnn_flops (line 106) | def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
  function rnn_flops_counter_hook (line 131) | def rnn_flops_counter_hook(rnn_module, input, output):
  function rnn_cell_flops_counter_hook (line 164) | def rnn_cell_flops_counter_hook(rnn_cell_module, input, output):
  function multihead_attention_counter_hook (line 181) | def multihead_attention_counter_hook(multihead_attention_module, input, ...
  function accumulate_flops (line 308) | def accumulate_flops(self):
  function get_model_parameters_number (line 318) | def get_model_parameters_number(model):
  function add_flops_counting_methods (line 323) | def add_flops_counting_methods(net_main_module):
  function compute_average_flops_cost (line 337) | def compute_average_flops_cost(self):
  function start_flops_count (line 357) | def start_flops_count(self, **kwargs):
  function stop_flops_count (line 393) | def stop_flops_count(self):
  function reset_flops_count (line 405) | def reset_flops_count(self):
  function batch_counter_hook (line 416) | def batch_counter_hook(module, input, output):
  function add_batch_counter_variables_or_reset (line 429) | def add_batch_counter_variables_or_reset(module):
  function add_batch_counter_hook_function (line 434) | def add_batch_counter_hook_function(module):
  function remove_batch_counter_hook_function (line 442) | def remove_batch_counter_hook_function(module):
  function add_flops_counter_variable_or_reset (line 448) | def add_flops_counter_variable_or_reset(module):
  function is_supported_instance (line 460) | def is_supported_instance(module):
  function remove_flops_counter_hook_function (line 466) | def remove_flops_counter_hook_function(module):
  function remove_flops_counter_variables (line 473) | def remove_flops_counter_variables(module):

FILE: ddpm_exp/torch_pruning/utils/utils.py
  function count_params (line 8) | def count_params(module):
  function flatten_as_list (line 11) | def flatten_as_list(obj):
  function draw_computational_graph (line 27) | def draw_computational_graph(DG, save_as, title='Computational Graph', f...
  function draw_groups (line 54) | def draw_groups(DG, save_as, title='Group', figsize=(16, 16), dpi=200, c...
  function draw_dependency_graph (line 95) | def draw_dependency_graph(DG, save_as, title='Group', figsize=(16, 16), ...

FILE: ddpm_exp/utils.py
  class UnlabeledImageFolder (line 5) | class UnlabeledImageFolder(torch.utils.data.Dataset):
    method __init__ (line 6) | def __init__(self, root, transform=None, exts=["*.jpg", "*.png", "*.jp...
    method __len__ (line 13) | def __len__(self):
    method __getitem__ (line 16) | def __getitem__(self, idx):
  function set_dropout (line 25) | def set_dropout(model, p):

FILE: ddpm_prune.py
  function reset_parameters (line 126) | def reset_parameters(model):

FILE: ddpm_train.py
  function parse_args (line 29) | def parse_args():
  function main (line 252) | def main(args):

FILE: diffusers/commands/__init__.py
  class BaseDiffusersCLICommand (line 19) | class BaseDiffusersCLICommand(ABC):
    method register_subcommand (line 22) | def register_subcommand(parser: ArgumentParser):
    method run (line 26) | def run(self):

FILE: diffusers/commands/diffusers_cli.py
  function main (line 21) | def main():

FILE: diffusers/commands/env.py
  function info_command_factory (line 25) | def info_command_factory(_):
  class EnvironmentCommand (line 29) | class EnvironmentCommand(BaseDiffusersCLICommand):
    method register_subcommand (line 31) | def register_subcommand(parser: ArgumentParser):
    method run (line 35) | def run(self):
    method format_dict (line 83) | def format_dict(d):

FILE: diffusers/configuration_utils.py
  class FrozenDict (line 50) | class FrozenDict(OrderedDict):
    method __init__ (line 51) | def __init__(self, *args, **kwargs):
    method __delitem__ (line 59) | def __delitem__(self, *args, **kwargs):
    method setdefault (line 62) | def setdefault(self, *args, **kwargs):
    method pop (line 65) | def pop(self, *args, **kwargs):
    method update (line 68) | def update(self, *args, **kwargs):
    method __setattr__ (line 71) | def __setattr__(self, name, value):
    method __setitem__ (line 76) | def __setitem__(self, name, value):
  class ConfigMixin (line 82) | class ConfigMixin:
    method register_to_config (line 105) | def register_to_config(self, **kwargs):
    method __getattr__ (line 122) | def __getattr__(self, name: str) -> Any:
    method save_config (line 140) | def save_config(self, save_directory: Union[str, os.PathLike], push_to...
    method from_config (line 161) | def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None,...
    method get_config_dict (line 244) | def get_config_dict(cls, *args, **kwargs):
    method load_config (line 253) | def load_config(
    method _get_init_keys (line 429) | def _get_init_keys(cls):
    method extract_init_dict (line 433) | def extract_init_dict(cls, config_dict, **kwargs):
    method _dict_from_json_file (line 517) | def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
    method __repr__ (line 522) | def __repr__(self):
    method config (line 526) | def config(self) -> Dict[str, Any]:
    method to_json_string (line 535) | def to_json_string(self) -> str:
    method to_json_file (line 559) | def to_json_file(self, json_file_path: Union[str, os.PathLike]):
  function register_to_config (line 571) | def register_to_config(init):
  function flax_register_to_config (line 616) | def flax_register_to_config(cls):

FILE: diffusers/dependency_versions_check.py
  function dep_version_check (line 46) | def dep_version_check(pkg, hint=None):

FILE: diffusers/experimental/rl/value_guided_sampling.py
  class ValueGuidedRLPipeline (line 25) | class ValueGuidedRLPipeline(DiffusionPipeline):
    method __init__ (line 42) | def __init__(
    method normalize (line 70) | def normalize(self, x_in, key):
    method de_normalize (line 73) | def de_normalize(self, x_in, key):
    method to_torch (line 76) | def to_torch(self, x_in):
    method reset_x0 (line 83) | def reset_x0(self, x_in, cond, act_dim):
    method run_diffusion (line 88) | def run_diffusion(self, x, conditions, n_guide_steps, scale):
    method __call__ (line 121) | def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_st...

FILE: diffusers/image_processor.py
  class VaeImageProcessor (line 27) | class VaeImageProcessor(ConfigMixin):
    method __init__ (line 46) | def __init__(
    method numpy_to_pil (line 56) | def numpy_to_pil(images):
    method numpy_to_pt (line 72) | def numpy_to_pt(images):
    method pt_to_numpy (line 83) | def pt_to_numpy(images):
    method normalize (line 91) | def normalize(images):
    method denormalize (line 98) | def denormalize(images):
    method resize (line 104) | def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
    method preprocess (line 113) | def preprocess(
    method postprocess (line 173) | def postprocess(

FILE: diffusers/loaders.py
  class AttnProcsLayers (line 66) | class AttnProcsLayers(torch.nn.Module):
    method __init__ (line 67) | def __init__(self, state_dict: Dict[str, torch.Tensor]):
  class UNet2DConditionLoadersMixin (line 108) | class UNet2DConditionLoadersMixin:
    method load_attn_procs (line 112) | def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union...
    method save_attn_procs (line 327) | def save_attn_procs(
  class TextualInversionLoaderMixin (line 404) | class TextualInversionLoaderMixin:
    method maybe_convert_prompt (line 409) | def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenize...
    method _maybe_convert_prompt (line 437) | def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTok...
    method load_textual_inversion (line 466) | def load_textual_inversion(
  class LoraLoaderMixin (line 745) | class LoraLoaderMixin:
    method load_lora_weights (line 759) | def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Uni...
    method text_encoder_lora_attn_procs (line 926) | def text_encoder_lora_attn_procs(self):
    method _modify_text_encoder (line 931) | def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProc...
    method _get_lora_layer_attribute (line 955) | def _get_lora_layer_attribute(self, name: str) -> str:
    method _load_text_encoder_attn_procs (line 965) | def _load_text_encoder_attn_procs(
    method save_lora_weights (line 1131) | def save_lora_weights(
  class FromCkptMixin (line 1212) | class FromCkptMixin:
    method from_ckpt (line 1217) | def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):

FILE: diffusers/models/attention.py
  class BasicTransformerBlock (line 26) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 47) | def __init__(
    method forward (line 121) | def forward(
  class FeedForward (line 185) | class FeedForward(nn.Module):
    method __init__ (line 198) | def __init__(
    method forward (line 231) | def forward(self, hidden_states):
  class GELU (line 237) | class GELU(nn.Module):
    method __init__ (line 242) | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
    method gelu (line 247) | def gelu(self, gate):
    method forward (line 253) | def forward(self, hidden_states):
  class GEGLU (line 259) | class GEGLU(nn.Module):
    method __init__ (line 268) | def __init__(self, dim_in: int, dim_out: int):
    method gelu (line 272) | def gelu(self, gate):
    method forward (line 278) | def forward(self, hidden_states):
  class ApproximateGELU (line 283) | class ApproximateGELU(nn.Module):
    method __init__ (line 290) | def __init__(self, dim_in: int, dim_out: int):
    method forward (line 294) | def forward(self, x):
  class AdaLayerNorm (line 299) | class AdaLayerNorm(nn.Module):
    method __init__ (line 304) | def __init__(self, embedding_dim, num_embeddings):
    method forward (line 311) | def forward(self, x, timestep):
  class AdaLayerNormZero (line 318) | class AdaLayerNormZero(nn.Module):
    method __init__ (line 323) | def __init__(self, embedding_dim, num_embeddings):
    method forward (line 332) | def forward(self, x, timestep, class_labels, hidden_dtype=None):
  class AdaGroupNorm (line 339) | class AdaGroupNorm(nn.Module):
    method __init__ (line 344) | def __init__(
    method forward (line 362) | def forward(self, x, emb):

FILE: diffusers/models/attention_flax.py
  function _query_chunk_attention (line 23) | def _query_chunk_attention(query, key, value, precision, key_chunk_size:...
  function jax_memory_efficient_attention (line 74) | def jax_memory_efficient_attention(
  class FlaxAttention (line 119) | class FlaxAttention(nn.Module):
    method setup (line 145) | def setup(self):
    method reshape_heads_to_batch_dim (line 156) | def reshape_heads_to_batch_dim(self, tensor):
    method reshape_batch_dim_to_heads (line 164) | def reshape_batch_dim_to_heads(self, tensor):
    method __call__ (line 172) | def __call__(self, hidden_states, context=None, deterministic=True):
  class FlaxBasicTransformerBlock (line 220) | class FlaxBasicTransformerBlock(nn.Module):
    method setup (line 250) | def setup(self):
    method __call__ (line 264) | def __call__(self, hidden_states, context, deterministic=True):
  class FlaxTransformer2DModel (line 286) | class FlaxTransformer2DModel(nn.Module):
    method setup (line 320) | def setup(self):
    method __call__ (line 359) | def __call__(self, hidden_states, context, deterministic=True):
  class FlaxFeedForward (line 384) | class FlaxFeedForward(nn.Module):
    method setup (line 405) | def setup(self):
    method __call__ (line 411) | def __call__(self, hidden_states, deterministic=True):
  class FlaxGEGLU (line 417) | class FlaxGEGLU(nn.Module):
    method setup (line 434) | def setup(self):
    method __call__ (line 438) | def __call__(self, hidden_states, deterministic=True):

FILE: diffusers/models/attention_processor.py
  class Attention (line 36) | class Attention(nn.Module):
    method __init__ (line 51) | def __init__(
    method set_use_memory_efficient_attention_xformers (line 159) | def set_use_memory_efficient_attention_xformers(
    method set_attention_slice (line 261) | def set_attention_slice(self, slice_size):
    method set_processor (line 282) | def set_processor(self, processor: "AttnProcessor"):
    method forward (line 295) | def forward(self, hidden_states, encoder_hidden_states=None, attention...
    method batch_to_head_dim (line 307) | def batch_to_head_dim(self, tensor):
    method head_to_batch_dim (line 314) | def head_to_batch_dim(self, tensor, out_dim=3):
    method get_attention_scores (line 325) | def get_attention_scores(self, query, key, attention_mask=None):
    method prepare_attention_mask (line 359) | def prepare_attention_mask(self, attention_mask, target_length, batch_...
    method norm_encoder_hidden_states (line 395) | def norm_encoder_hidden_states(self, encoder_hidden_states):
  class AttnProcessor (line 415) | class AttnProcessor:
    method __call__ (line 416) | def __call__(
  class LoRALinearLayer (line 473) | class LoRALinearLayer(nn.Module):
    method __init__ (line 474) | def __init__(self, in_features, out_features, rank=4):
    method forward (line 486) | def forward(self, hidden_states):
  class LoRAAttnProcessor (line 496) | class LoRAAttnProcessor(nn.Module):
    method __init__ (line 497) | def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
    method __call__ (line 509) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class CustomDiffusionAttnProcessor (line 560) | class CustomDiffusionAttnProcessor(nn.Module):
    method __init__ (line 561) | def __init__(
    method __call__ (line 587) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class AttnAddedKVProcessor (line 638) | class AttnAddedKVProcessor:
    method __call__ (line 639) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class AttnAddedKVProcessor2_0 (line 687) | class AttnAddedKVProcessor2_0:
    method __init__ (line 688) | def __init__(self):
    method __call__ (line 694) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class LoRAAttnAddedKVProcessor (line 745) | class LoRAAttnAddedKVProcessor(nn.Module):
    method __init__ (line 746) | def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
    method __call__ (line 760) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class XFormersAttnProcessor (line 812) | class XFormersAttnProcessor:
    method __init__ (line 813) | def __init__(self, attention_op: Optional[Callable] = None):
    method __call__ (line 816) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class AttnProcessor2_0 (line 870) | class AttnProcessor2_0:
    method __init__ (line 871) | def __init__(self):
    method __call__ (line 875) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class LoRAXFormersAttnProcessor (line 938) | class LoRAXFormersAttnProcessor(nn.Module):
    method __init__ (line 939) | def __init__(self, hidden_size, cross_attention_dim, rank=4, attention...
    method __call__ (line 952) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class CustomDiffusionXFormersAttnProcessor (line 1004) | class CustomDiffusionXFormersAttnProcessor(nn.Module):
    method __init__ (line 1005) | def __init__(
    method __call__ (line 1033) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class SlicedAttnProcessor (line 1089) | class SlicedAttnProcessor:
    method __init__ (line 1090) | def __init__(self, slice_size):
    method __call__ (line 1093) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class SlicedAttnAddedKVProcessor (line 1161) | class SlicedAttnAddedKVProcessor:
    method __init__ (line 1162) | def __init__(self, slice_size):
    method __call__ (line 1165) | def __call__(self, attn: "Attention", hidden_states, encoder_hidden_st...

FILE: diffusers/models/autoencoder_kl.py
  class AutoencoderKLOutput (line 27) | class AutoencoderKLOutput(BaseOutput):
  class AutoencoderKL (line 40) | class AutoencoderKL(ModelMixin, ConfigMixin):
    method __init__ (line 71) | def __init__(
    method _set_gradient_checkpointing (line 126) | def _set_gradient_checkpointing(self, module, value=False):
    method enable_tiling (line 130) | def enable_tiling(self, use_tiling: bool = True):
    method disable_tiling (line 138) | def disable_tiling(self):
    method enable_slicing (line 145) | def enable_slicing(self):
    method disable_slicing (line 152) | def disable_slicing(self):
    method encode (line 160) | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Au...
    method _decode (line 173) | def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> U...
    method decode (line 186) | def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Un...
    method blend_v (line 198) | def blend_v(self, a, b, blend_extent):
    method blend_h (line 204) | def blend_h(self, a, b, blend_extent):
    method tiled_encode (line 210) | def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True)...
    method tiled_decode (line 257) | def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True)...
    method forward (line 304) | def forward(

FILE: diffusers/models/controlnet.py
  class ControlNetOutput (line 39) | class ControlNetOutput(BaseOutput):
  class ControlNetConditioningEmbedding (line 44) | class ControlNetConditioningEmbedding(nn.Module):
    method __init__ (line 54) | def __init__(
    method forward (line 76) | def forward(self, conditioning):
  class ControlNetModel (line 89) | class ControlNetModel(ModelMixin, ConfigMixin):
    method __init__ (line 93) | def __init__(
    method from_unet (line 263) | def from_unet(
    method attn_processors (line 318) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 342) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 373) | def set_default_attn_processor(self):
    method set_attention_slice (line 380) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 445) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 449) | def forward(
  function zero_module (line 585) | def zero_module(module):

FILE: diffusers/models/controlnet_flax.py
  class FlaxControlNetOutput (line 34) | class FlaxControlNetOutput(BaseOutput):
  class FlaxControlNetConditioningEmbedding (line 39) | class FlaxControlNetConditioningEmbedding(nn.Module):
    method setup (line 44) | def setup(self):
    method __call__ (line 82) | def __call__(self, conditioning):
  class FlaxControlNetModel (line 96) | class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    method init_weights (line 167) | def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
    method setup (line 181) | def setup(self):
    method __call__ (line 302) | def __call__(

FILE: diffusers/models/cross_attention.py
  class CrossAttention (line 41) | class CrossAttention(Attention):
    method __init__ (line 42) | def __init__(self, *args, **kwargs):
  class CrossAttnProcessor (line 48) | class CrossAttnProcessor(AttnProcessorRename):
    method __init__ (line 49) | def __init__(self, *args, **kwargs):
  class LoRACrossAttnProcessor (line 55) | class LoRACrossAttnProcessor(LoRAAttnProcessor):
    method __init__ (line 56) | def __init__(self, *args, **kwargs):
  class CrossAttnAddedKVProcessor (line 62) | class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
    method __init__ (line 63) | def __init__(self, *args, **kwargs):
  class XFormersCrossAttnProcessor (line 69) | class XFormersCrossAttnProcessor(XFormersAttnProcessor):
    method __init__ (line 70) | def __init__(self, *args, **kwargs):
  class LoRAXFormersCrossAttnProcessor (line 76) | class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
    method __init__ (line 77) | def __init__(self, *args, **kwargs):
  class SlicedCrossAttnProcessor (line 83) | class SlicedCrossAttnProcessor(SlicedAttnProcessor):
    method __init__ (line 84) | def __init__(self, *args, **kwargs):
  class SlicedCrossAttnAddedKVProcessor (line 90) | class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    method __init__ (line 91) | def __init__(self, *args, **kwargs):

FILE: diffusers/models/dual_transformer_2d.py
  class DualTransformer2DModel (line 21) | class DualTransformer2DModel(nn.Module):
    method __init__ (line 48) | def __init__(
    method forward (line 97) | def forward(

FILE: diffusers/models/embeddings.py
  function get_timestep_embedding (line 22) | def get_timestep_embedding(
  function get_2d_sincos_pos_embed (line 65) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra...
  function get_2d_sincos_pos_embed_from_grid (line 82) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 94) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  class PatchEmbed (line 115) | class PatchEmbed(nn.Module):
    method __init__ (line 118) | def __init__(
    method forward (line 146) | def forward(self, latent):
  class TimestepEmbedding (line 155) | class TimestepEmbedding(nn.Module):
    method __init__ (line 156) | def __init__(
    method forward (line 200) | def forward(self, sample, condition=None):
  class Timesteps (line 215) | class Timesteps(nn.Module):
    method __init__ (line 216) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
    method forward (line 222) | def forward(self, timesteps):
  class GaussianFourierProjection (line 232) | class GaussianFourierProjection(nn.Module):
    method __init__ (line 235) | def __init__(
    method forward (line 249) | def forward(self, x):
  class ImagePositionalEmbeddings (line 262) | class ImagePositionalEmbeddings(nn.Module):
    method __init__ (line 286) | def __init__(
    method forward (line 304) | def forward(self, index):
  class LabelEmbedding (line 327) | class LabelEmbedding(nn.Module):
    method __init__ (line 337) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method token_drop (line 344) | def token_drop(self, labels, force_drop_ids=None):
    method forward (line 355) | def forward(self, labels, force_drop_ids=None):
  class CombinedTimestepLabelEmbeddings (line 363) | class CombinedTimestepLabelEmbeddings(nn.Module):
    method __init__ (line 364) | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
    method forward (line 371) | def forward(self, timestep, class_labels, hidden_dtype=None):
  class TextTimeEmbedding (line 382) | class TextTimeEmbedding(nn.Module):
    method __init__ (line 383) | def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: i...
    method forward (line 390) | def forward(self, hidden_states):
  class AttentionPooling (line 398) | class AttentionPooling(nn.Module):
    method __init__ (line 401) | def __init__(self, num_heads, embed_dim, dtype=None):
    method forward (line 411) | def forward(self, x):

FILE: diffusers/models/embeddings_flax.py
  function get_sinusoidal_embeddings (line 20) | def get_sinusoidal_embeddings(
  class FlaxTimestepEmbedding (line 58) | class FlaxTimestepEmbedding(nn.Module):
    method __call__ (line 72) | def __call__(self, temb):
  class FlaxTimesteps (line 79) | class FlaxTimesteps(nn.Module):
    method __call__ (line 92) | def __call__(self, timesteps):

FILE: diffusers/models/modeling_flax_pytorch_utils.py
  function rename_key (line 28) | def rename_key(key):
  function rename_key_and_reshape_tensor (line 43) | def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_s...
  function convert_pytorch_state_dict_to_flax (line 90) | def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_k...

FILE: diffusers/models/modeling_flax_utils.py
  class FlaxModelMixin (line 45) | class FlaxModelMixin:
    method _from_config (line 57) | def _from_config(cls, config, **kwargs):
    method _cast_floating_to (line 63) | def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jn...
    method to_bf16 (line 87) | def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    method to_fp32 (line 126) | def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
    method to_fp16 (line 153) | def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    method init_weights (line 192) | def init_weights(self, rng: jax.random.KeyArray) -> Dict:
    method from_pretrained (line 196) | def from_pretrained(
    method save_pretrained (line 487) | def save_pretrained(

FILE: diffusers/models/modeling_pytorch_flax_utils.py
  function load_flax_checkpoint_in_pytorch_model (line 37) | def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
  function load_flax_weights_in_pytorch_model (line 58) | def load_flax_weights_in_pytorch_model(pt_model, flax_state):

FILE: diffusers/models/modeling_utils.py
  function get_parameter_device (line 62) | def get_parameter_device(parameter: torch.nn.Module):
  function get_parameter_dtype (line 78) | def get_parameter_dtype(parameter: torch.nn.Module):
  function load_state_dict (line 100) | def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: O...
  function _load_state_dict_into_model (line 131) | def _load_state_dict_into_model(model_to_load, state_dict):
  class ModelMixin (line 152) | class ModelMixin(torch.nn.Module):
    method __init__ (line 166) | def __init__(self):
    method __getattr__ (line 169) | def __getattr__(self, name: str) -> Any:
    method is_gradient_checkpointing (line 188) | def is_gradient_checkpointing(self) -> bool:
    method enable_gradient_checkpointing (line 197) | def enable_gradient_checkpointing(self):
    method disable_gradient_checkpointing (line 208) | def disable_gradient_checkpointing(self):
    method set_use_memory_efficient_attention_xformers (line 218) | def set_use_memory_efficient_attention_xformers(
    method enable_xformers_memory_efficient_attention (line 235) | def enable_xformers_memory_efficient_attention(self, attention_op: Opt...
    method disable_xformers_memory_efficient_attention (line 267) | def disable_xformers_memory_efficient_attention(self):
    method save_pretrained (line 273) | def save_pretrained(
    method from_pretrained (line 334) | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union...
    method _load_pretrained_model (line 663) | def _load_pretrained_model(
    method device (line 767) | def device(self) -> device:
    method dtype (line 775) | def dtype(self) -> torch.dtype:
    method num_parameters (line 781) | def num_parameters(self, only_trainable: bool = False, exclude_embeddi...
    method _convert_deprecated_attention_blocks (line 809) | def _convert_deprecated_attention_blocks(self, state_dict):

FILE: diffusers/models/prior_transformer.py
  class PriorTransformerOutput (line 16) | class PriorTransformerOutput(BaseOutput):
  class PriorTransformer (line 26) | class PriorTransformer(ModelMixin, ConfigMixin):
    method __init__ (line 52) | def __init__(
    method forward (line 107) | def forward(
    method post_process_latents (line 192) | def post_process_latents(self, prior_latents):

FILE: diffusers/models/resnet.py
  class Upsample1D (line 26) | class Upsample1D(nn.Module):
    method __init__ (line 40) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
    method forward (line 54) | def forward(self, x):
  class Downsample1D (line 67) | class Downsample1D(nn.Module):
    method __init__ (line 81) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 96) | def forward(self, x):
  class Upsample2D (line 101) | class Upsample2D(nn.Module):
    method __init__ (line 115) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
    method forward (line 135) | def forward(self, hidden_states, output_size=None):
  class Downsample2D (line 173) | class Downsample2D(nn.Module):
    method __init__ (line 187) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
    method forward (line 211) | def forward(self, hidden_states):
  class FirUpsample2D (line 223) | class FirUpsample2D(nn.Module):
    method __init__ (line 237) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
    method _upsample_2d (line 246) | def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor...
    method forward (line 326) | def forward(self, hidden_states):
  class FirDownsample2D (line 336) | class FirDownsample2D(nn.Module):
    method __init__ (line 350) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
    method _downsample_2d (line 359) | def _downsample_2d(self, hidden_states, weight=None, kernel=None, fact...
    method forward (line 413) | def forward(self, hidden_states):
  class KDownsample2D (line 424) | class KDownsample2D(nn.Module):
    method __init__ (line 425) | def __init__(self, pad_mode="reflect"):
    method forward (line 432) | def forward(self, x):
  class KUpsample2D (line 440) | class KUpsample2D(nn.Module):
    method __init__ (line 441) | def __init__(self, pad_mode="reflect"):
    method forward (line 448) | def forward(self, x):
  class ResnetBlock2D (line 456) | class ResnetBlock2D(nn.Module):
    method __init__ (line 487) | def __init__(
    method forward (line 589) | def forward(self, input_tensor, temb):
  class Mish (line 642) | class Mish(torch.nn.Module):
    method forward (line 643) | def forward(self, hidden_states):
  function rearrange_dims (line 648) | def rearrange_dims(tensor):
  class Conv1dBlock (line 659) | class Conv1dBlock(nn.Module):
    method __init__ (line 664) | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
    method forward (line 671) | def forward(self, x):
  class ResidualTemporalBlock1D (line 681) | class ResidualTemporalBlock1D(nn.Module):
    method __init__ (line 682) | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
    method forward (line 694) | def forward(self, x, t):
  function upsample_2d (line 710) | def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
  function downsample_2d (line 747) | def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
  function upfirdn2d_native (line 782) | def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
  class TemporalConvLayer (line 826) | class TemporalConvLayer(nn.Module):
    method __init__ (line 832) | def __init__(self, in_dim, out_dim=None, dropout=0.0):
    method forward (line 865) | def forward(self, hidden_states, num_frames=1):

FILE: diffusers/models/resnet_flax.py
  class FlaxUpsample2D (line 19) | class FlaxUpsample2D(nn.Module):
    method setup (line 23) | def setup(self):
    method __call__ (line 32) | def __call__(self, hidden_states):
  class FlaxDownsample2D (line 43) | class FlaxDownsample2D(nn.Module):
    method setup (line 47) | def setup(self):
    method __call__ (line 56) | def __call__(self, hidden_states):
  class FlaxResnetBlock2D (line 63) | class FlaxResnetBlock2D(nn.Module):
    method setup (line 70) | def setup(self):
    method __call__ (line 106) | def __call__(self, hidden_states, temb, deterministic=True):

FILE: diffusers/models/t5_film_transformer.py
  class T5FilmDecoder (line 25) | class T5FilmDecoder(ModelMixin, ConfigMixin):
    method __init__ (line 27) | def __init__(
    method encoder_decoder_mask (line 66) | def encoder_decoder_mask(self, query_input, key_input):
    method forward (line 70) | def forward(self, encodings_and_masks, decoder_input_tokens, decoder_n...
  class DecoderLayer (line 127) | class DecoderLayer(nn.Module):
    method __init__ (line 128) | def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer...
    method forward (line 153) | def forward(
  class T5LayerSelfAttentionCond (line 185) | class T5LayerSelfAttentionCond(nn.Module):
    method __init__ (line 186) | def __init__(self, d_model, d_kv, num_heads, dropout_rate):
    method forward (line 193) | def forward(
  class T5LayerCrossAttention (line 213) | class T5LayerCrossAttention(nn.Module):
    method __init__ (line 214) | def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_...
    method forward (line 220) | def forward(
  class T5LayerFFCond (line 236) | class T5LayerFFCond(nn.Module):
    method __init__ (line 237) | def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
    method forward (line 244) | def forward(self, hidden_states, conditioning_emb=None):
  class T5DenseGatedActDense (line 254) | class T5DenseGatedActDense(nn.Module):
    method __init__ (line 255) | def __init__(self, d_model, d_ff, dropout_rate):
    method forward (line 263) | def forward(self, hidden_states):
  class T5LayerNorm (line 273) | class T5LayerNorm(nn.Module):
    method __init__ (line 274) | def __init__(self, hidden_size, eps=1e-6):
    method forward (line 282) | def forward(self, hidden_states):
  class NewGELUActivation (line 298) | class NewGELUActivation(nn.Module):
    method forward (line 304) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class T5FiLMLayer (line 308) | class T5FiLMLayer(nn.Module):
    method __init__ (line 313) | def __init__(self, in_features, out_features):
    method forward (line 317) | def forward(self, x, conditioning_emb):

FILE: diffusers/models/transformer_2d.py
  class Transformer2DModelOutput (line 30) | class Transformer2DModelOutput(BaseOutput):
  class Transformer2DModel (line 41) | class Transformer2DModel(ModelMixin, ConfigMixin):
    method __init__ (line 80) | def __init__(
    method forward (line 214) | def forward(

FILE: diffusers/models/transformer_temporal.py
  class TransformerTemporalModelOutput (line 27) | class TransformerTemporalModelOutput(BaseOutput):
  class TransformerTemporalModel (line 37) | class TransformerTemporalModel(ModelMixin, ConfigMixin):
    method __init__ (line 60) | def __init__(
    method forward (line 106) | def forward(

FILE: diffusers/models/unet_1d.py
  class UNet1DOutput (line 29) | class UNet1DOutput(BaseOutput):
  class UNet1DModel (line 39) | class UNet1DModel(ModelMixin, ConfigMixin):
    method __init__ (line 73) | def __init__(
    method forward (line 193) | def forward(

FILE: diffusers/models/unet_1d_blocks.py
  class DownResnetBlock1D (line 23) | class DownResnetBlock1D(nn.Module):
    method __init__ (line 24) | def __init__(
    method forward (line 71) | def forward(self, hidden_states, temb=None):
  class UpResnetBlock1D (line 89) | class UpResnetBlock1D(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 135) | def forward(self, hidden_states, res_hidden_states_tuple=None, temb=No...
  class ValueFunctionMidBlock1D (line 153) | class ValueFunctionMidBlock1D(nn.Module):
    method __init__ (line 154) | def __init__(self, in_channels, out_channels, embed_dim):
    method forward (line 165) | def forward(self, x, temb=None):
  class MidResTemporalBlock1D (line 173) | class MidResTemporalBlock1D(nn.Module):
    method __init__ (line 174) | def __init__(
    method forward (line 217) | def forward(self, hidden_states, temb):
  class OutConv1DBlock (line 230) | class OutConv1DBlock(nn.Module):
    method __init__ (line 231) | def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
    method forward (line 241) | def forward(self, hidden_states, temb=None):
  class OutValueFunctionBlock (line 251) | class OutValueFunctionBlock(nn.Module):
    method __init__ (line 252) | def __init__(self, fc_dim, embed_dim):
    method forward (line 262) | def forward(self, hidden_states, temb):
  class Downsample1d (line 291) | class Downsample1d(nn.Module):
    method __init__ (line 292) | def __init__(self, kernel="linear", pad_mode="reflect"):
    method forward (line 299) | def forward(self, hidden_states):
  class Upsample1d (line 307) | class Upsample1d(nn.Module):
    method __init__ (line 308) | def __init__(self, kernel="linear", pad_mode="reflect"):
    method forward (line 315) | def forward(self, hidden_states, temb=None):
  class SelfAttention1d (line 323) | class SelfAttention1d(nn.Module):
    method __init__ (line 324) | def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
    method transpose_for_scores (line 338) | def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
    method forward (line 344) | def forward(self, hidden_states):
  class ResConvBlock (line 381) | class ResConvBlock(nn.Module):
    method __init__ (line 382) | def __init__(self, in_channels, mid_channels, out_channels, is_last=Fa...
    method forward (line 399) | def forward(self, hidden_states):
  class UNetMidBlock1D (line 415) | class UNetMidBlock1D(nn.Module):
    method __init__ (line 416) | def __init__(self, mid_channels, in_channels, out_channels=None):
    method forward (line 444) | def forward(self, hidden_states, temb=None):
  class AttnDownBlock1D (line 455) | class AttnDownBlock1D(nn.Module):
    method __init__ (line 456) | def __init__(self, out_channels, in_channels, mid_channels=None):
    method forward (line 475) | def forward(self, hidden_states, temb=None):
  class DownBlock1D (line 485) | class DownBlock1D(nn.Module):
    method __init__ (line 486) | def __init__(self, out_channels, in_channels, mid_channels=None):
    method forward (line 499) | def forward(self, hidden_states, temb=None):
  class DownBlock1DNoSkip (line 508) | class DownBlock1DNoSkip(nn.Module):
    method __init__ (line 509) | def __init__(self, out_channels, in_channels, mid_channels=None):
    method forward (line 521) | def forward(self, hidden_states, temb=None):
  class AttnUpBlock1D (line 529) | class AttnUpBlock1D(nn.Module):
    method __init__ (line 530) | def __init__(self, in_channels, out_channels, mid_channels=None):
    method forward (line 549) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
  class UpBlock1D (line 562) | class UpBlock1D(nn.Module):
    method __init__ (line 563) | def __init__(self, in_channels, out_channels, mid_channels=None):
    method forward (line 576) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
  class UpBlock1DNoSkip (line 588) | class UpBlock1DNoSkip(nn.Module):
    method __init__ (line 589) | def __init__(self, in_channels, out_channels, mid_channels=None):
    method forward (line 601) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
  function get_down_block (line 611) | def get_down_block(down_block_type, num_layers, in_channels, out_channel...
  function get_up_block (line 629) | def get_up_block(up_block_type, num_layers, in_channels, out_channels, t...
  function get_mid_block (line 647) | def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels,...
  function get_out_block (line 663) | def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_chan...

FILE: diffusers/models/unet_2d.py
  class UNet2DOutput (line 28) | class UNet2DOutput(BaseOutput):
  class UNet2DModel (line 38) | class UNet2DModel(ModelMixin, ConfigMixin):
    method __init__ (line 83) | def __init__(
    method forward (line 219) | def forward(

FILE: diffusers/models/unet_2d_blocks.py
  function get_down_block (line 29) | def get_down_block(
  function get_up_block (line 210) | def get_up_block(
  class UNetMidBlock2D (line 391) | class UNetMidBlock2D(nn.Module):
    method __init__ (line 392) | def __init__(
    method forward (line 465) | def forward(self, hidden_states, temb=None):
  class UNetMidBlock2DCrossAttn (line 475) | class UNetMidBlock2DCrossAttn(nn.Module):
    method __init__ (line 476) | def __init__(
    method forward (line 560) | def forward(
  class UNetMidBlock2DSimpleCrossAttn (line 576) | class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    method __init__ (line 577) | def __init__(
    method forward (line 661) | def forward(
  class AttnDownBlock2D (line 681) | class AttnDownBlock2D(nn.Module):
    method __init__ (line 682) | def __init__(
    method forward (line 748) | def forward(self, hidden_states, temb=None, upsample_size=None):
  class CrossAttnDownBlock2D (line 765) | class CrossAttnDownBlock2D(nn.Module):
    method __init__ (line 766) | def __init__(
    method forward (line 852) | def forward(
  class DownBlock2D (line 911) | class DownBlock2D(nn.Module):
    method __init__ (line 912) | def __init__(
    method forward (line 963) | def forward(self, hidden_states, temb=None):
  class DownEncoderBlock2D (line 997) | class DownEncoderBlock2D(nn.Module):
    method __init__ (line 998) | def __init__(
    method forward (line 1046) | def forward(self, hidden_states):
  class AttnDownEncoderBlock2D (line 1057) | class AttnDownEncoderBlock2D(nn.Module):
    method __init__ (line 1058) | def __init__(
    method forward (line 1123) | def forward(self, hidden_states):
  class AttnSkipDownBlock2D (line 1135) | class AttnSkipDownBlock2D(nn.Module):
    method __init__ (line 1136) | def __init__(
    method forward (line 1211) | def forward(self, hidden_states, temb=None, skip_sample=None):
  class SkipDownBlock2D (line 1231) | class SkipDownBlock2D(nn.Module):
    method __init__ (line 1232) | def __init__(
    method forward (line 1291) | def forward(self, hidden_states, temb=None, skip_sample=None):
  class ResnetDownsampleBlock2D (line 1310) | class ResnetDownsampleBlock2D(nn.Module):
    method __init__ (line 1311) | def __init__(
    method forward (line 1374) | def forward(self, hidden_states, temb=None):
  class SimpleCrossAttnDownBlock2D (line 1408) | class SimpleCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1409) | def __init__(
    method forward (line 1503) | def forward(
  class KDownBlock2D (line 1549) | class KDownBlock2D(nn.Module):
    method __init__ (line 1550) | def __init__(
    method forward (line 1595) | def forward(self, hidden_states, temb=None):
  class KCrossAttnDownBlock2D (line 1627) | class KCrossAttnDownBlock2D(nn.Module):
    method __init__ (line 1628) | def __init__(
    method forward (line 1692) | def forward(
  class AttnUpBlock2D (line 1754) | class AttnUpBlock2D(nn.Module):
    method __init__ (line 1755) | def __init__(
    method forward (line 1817) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class CrossAttnUpBlock2D (line 1834) | class CrossAttnUpBlock2D(nn.Module):
    method __init__ (line 1835) | def __init__(
    method forward (line 1917) | def forward(
  class UpBlock2D (line 1982) | class UpBlock2D(nn.Module):
    method __init__ (line 1983) | def __init__(
    method forward (line 2030) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class UpDecoderBlock2D (line 2063) | class UpDecoderBlock2D(nn.Module):
    method __init__ (line 2064) | def __init__(
    method forward (line 2106) | def forward(self, hidden_states):
  class AttnUpDecoderBlock2D (line 2117) | class AttnUpDecoderBlock2D(nn.Module):
    method __init__ (line 2118) | def __init__(
    method forward (line 2177) | def forward(self, hidden_states):
  class AttnSkipUpBlock2D (line 2189) | class AttnSkipUpBlock2D(nn.Module):
    method __init__ (line 2190) | def __init__(
    method forward (line 2275) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
  class SkipUpBlock2D (line 2303) | class SkipUpBlock2D(nn.Module):
    method __init__ (line 2304) | def __init__(
    method forward (line 2372) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
  class ResnetUpsampleBlock2D (line 2398) | class ResnetUpsampleBlock2D(nn.Module):
    method __init__ (line 2399) | def __init__(
    method forward (line 2465) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class SimpleCrossAttnUpBlock2D (line 2498) | class SimpleCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2499) | def __init__(
    method forward (line 2595) | def forward(
  class KUpBlock2D (line 2648) | class KUpBlock2D(nn.Module):
    method __init__ (line 2649) | def __init__(
    method forward (line 2696) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class KCrossAttnUpBlock2D (line 2728) | class KCrossAttnUpBlock2D(nn.Module):
    method __init__ (line 2729) | def __init__(
    method forward (line 2812) | def forward(
  class KAttentionBlock (line 2879) | class KAttentionBlock(nn.Module):
    method __init__ (line 2896) | def __init__(
    method _to_3d (line 2939) | def _to_3d(self, hidden_states, height, weight):
    method _to_4d (line 2942) | def _to_4d(self, hidden_states, height, weight):
    method forward (line 2945) | def forward(

FILE: diffusers/models/unet_2d_blocks_flax.py
  class FlaxCrossAttnDownBlock2D (line 22) | class FlaxCrossAttnDownBlock2D(nn.Module):
    method setup (line 56) | def setup(self):
    method __call__ (line 89) | def __call__(self, hidden_states, temb, encoder_hidden_states, determi...
  class FlaxDownBlock2D (line 104) | class FlaxDownBlock2D(nn.Module):
    method setup (line 129) | def setup(self):
    method __call__ (line 147) | def __call__(self, hidden_states, temb, deterministic=True):
  class FlaxCrossAttnUpBlock2D (line 161) | class FlaxCrossAttnUpBlock2D(nn.Module):
    method setup (line 196) | def setup(self):
    method __call__ (line 230) | def __call__(self, hidden_states, res_hidden_states_tuple, temb, encod...
  class FlaxUpBlock2D (line 246) | class FlaxUpBlock2D(nn.Module):
    method setup (line 274) | def setup(self):
    method __call__ (line 294) | def __call__(self, hidden_states, res_hidden_states_tuple, temb, deter...
  class FlaxUNetMidBlock2DCrossAttn (line 309) | class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    method setup (line 335) | def setup(self):
    method __call__ (line 371) | def __call__(self, hidden_states, temb, encoder_hidden_states, determi...

FILE: diffusers/models/unet_2d_condition.py
  class UNet2DConditionOutput (line 44) | class UNet2DConditionOutput(BaseOutput):
  class UNet2DConditionModel (line 54) | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
    method __init__ (line 133) | def __init__(
    method attn_processors (line 482) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 505) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 535) | def set_default_attn_processor(self):
    method set_attention_slice (line 541) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 606) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 610) | def forward(

FILE: diffusers/models/unet_2d_condition_flax.py
  class FlaxUNet2DConditionOutput (line 36) | class FlaxUNet2DConditionOutput(BaseOutput):
  class FlaxUNet2DConditionModel (line 47) | class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    method init_weights (line 118) | def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
    method setup (line 130) | def setup(self):
    method __call__ (line 253) | def __call__(

FILE: diffusers/models/unet_3d_blocks.py
  function get_down_block (line 23) | def get_down_block(
  function get_up_block (line 79) | def get_up_block(
  class UNetMidBlock3DCrossAttn (line 135) | class UNetMidBlock3DCrossAttn(nn.Module):
    method __init__ (line 136) | def __init__(
    method forward (line 235) | def forward(
  class CrossAttnDownBlock3D (line 263) | class CrossAttnDownBlock3D(nn.Module):
    method __init__ (line 264) | def __init__(
    method forward (line 359) | def forward(
  class DownBlock3D (line 396) | class DownBlock3D(nn.Module):
    method __init__ (line 397) | def __init__(
    method forward (line 457) | def forward(self, hidden_states, temb=None, num_frames=1):
  class CrossAttnUpBlock3D (line 475) | class CrossAttnUpBlock3D(nn.Module):
    method __init__ (line 476) | def __init__(
    method forward (line 567) | def forward(
  class UpBlock3D (line 605) | class UpBlock3D(nn.Module):
    method __init__ (line 606) | def __init__(
    method forward (line 662) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...

FILE: diffusers/models/unet_3d_condition.py
  class UNet3DConditionOutput (line 44) | class UNet3DConditionOutput(BaseOutput):
  class UNet3DConditionModel (line 54) | class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoade...
    method __init__ (line 87) | def __init__(
    method attn_processors (line 256) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attention_slice (line 280) | def set_attention_slice(self, slice_size):
    method set_attn_processor (line 346) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 377) | def set_default_attn_processor(self):
    method _set_gradient_checkpointing (line 383) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 387) | def forward(

FILE: diffusers/models/vae.py
  class DecoderOutput (line 26) | class DecoderOutput(BaseOutput):
  class Encoder (line 38) | class Encoder(nn.Module):
    method __init__ (line 39) | def __init__(
    method forward (line 107) | def forward(self, x):
  class Decoder (line 151) | class Decoder(nn.Module):
    method __init__ (line 152) | def __init__(
    method forward (line 220) | def forward(self, z):
  class VectorQuantizer (line 270) | class VectorQuantizer(nn.Module):
    method __init__ (line 279) | def __init__(
    method remap_to_used (line 308) | def remap_to_used(self, inds):
    method unmap_to_all (line 322) | def unmap_to_all(self, inds):
    method forward (line 332) | def forward(self, z):
    method get_codebook_entry (line 366) | def get_codebook_entry(self, indices, shape):
  class DiagonalGaussianDistribution (line 384) | class DiagonalGaussianDistribution(object):
    method __init__ (line 385) | def __init__(self, parameters, deterministic=False):
    method sample (line 397) | def sample(self, generator: Optional[torch.Generator] = None) -> torch...
    method kl (line 405) | def kl(self, other=None):
    method nll (line 421) | def nll(self, sample, dims=[1, 2, 3]):
    method mode (line 427) | def mode(self):

FILE: diffusers/models/vae_flax.py
  class FlaxDecoderOutput (line 33) | class FlaxDecoderOutput(BaseOutput):
  class FlaxAutoencoderKLOutput (line 48) | class FlaxAutoencoderKLOutput(BaseOutput):
  class FlaxUpsample2D (line 61) | class FlaxUpsample2D(nn.Module):
    method setup (line 75) | def setup(self):
    method __call__ (line 84) | def __call__(self, hidden_states):
  class FlaxDownsample2D (line 95) | class FlaxDownsample2D(nn.Module):
    method setup (line 109) | def setup(self):
    method __call__ (line 118) | def __call__(self, hidden_states):
  class FlaxResnetBlock2D (line 125) | class FlaxResnetBlock2D(nn.Module):
    method setup (line 151) | def setup(self):
    method __call__ (line 185) | def __call__(self, hidden_states, deterministic=True):
  class FlaxAttentionBlock (line 202) | class FlaxAttentionBlock(nn.Module):
    method setup (line 222) | def setup(self):
    method transpose_for_scores (line 231) | def transpose_for_scores(self, projection):
    method __call__ (line 239) | def __call__(self, hidden_states):
  class FlaxDownEncoderBlock2D (line 274) | class FlaxDownEncoderBlock2D(nn.Module):
    method setup (line 302) | def setup(self):
    method __call__ (line 320) | def __call__(self, hidden_states, deterministic=True):
  class FlaxUpDecoderBlock2D (line 330) | class FlaxUpDecoderBlock2D(nn.Module):
    method setup (line 358) | def setup(self):
    method __call__ (line 376) | def __call__(self, hidden_states, deterministic=True):
  class FlaxUNetMidBlock2D (line 386) | class FlaxUNetMidBlock2D(nn.Module):
    method setup (line 411) | def setup(self):
    method __call__ (line 448) | def __call__(self, hidden_states, deterministic=True):
  class FlaxEncoder (line 457) | class FlaxEncoder(nn.Module):
    method setup (line 501) | def setup(self):
    method __call__ (line 550) | def __call__(self, sample, deterministic: bool = True):
  class FlaxDecoder (line 569) | class FlaxDecoder(nn.Module):
    method setup (line 612) | def setup(self):
    method __call__ (line 665) | def __call__(self, sample, deterministic: bool = True):
  class FlaxDiagonalGaussianDistribution (line 683) | class FlaxDiagonalGaussianDistribution(object):
    method __init__ (line 684) | def __init__(self, parameters, deterministic=False):
    method sample (line 694) | def sample(self, key):
    method kl (line 697) | def kl(self, other=None):
    method nll (line 709) | def nll(self, sample, axis=[1, 2, 3]):
    method mode (line 716) | def mode(self):
  class FlaxAutoencoderKL (line 721) | class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    method setup (line 780) | def setup(self):
    method init_weights (line 817) | def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
    method encode (line 827) | def encode(self, sample, deterministic: bool = True, return_dict: bool...
    method decode (line 839) | def decode(self, latents, deterministic: bool = True, return_dict: boo...
    method __call__ (line 853) | def __call__(self, sample, sample_posterior=False, deterministic: bool...

FILE: diffusers/models/vq_model.py
  class VQEncoderOutput (line 27) | class VQEncoderOutput(BaseOutput):
  class VQModel (line 39) | class VQModel(ModelMixin, ConfigMixin):
    method __init__ (line 70) | def __init__(
    method encode (line 117) | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQ...
    method decode (line 126) | def decode(
    method forward (line 142) | def forward(self, sample: torch.FloatTensor, return_dict: bool = True)...

FILE: diffusers/optimization.py
  class SchedulerType (line 30) | class SchedulerType(Enum):
  function get_constant_schedule (line 40) | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
  function get_constant_schedule_with_warmup (line 56) | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_s...
  function get_piecewise_constant_schedule (line 81) | def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: st...
  function get_linear_schedule_with_warmup (line 123) | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_tra...
  function get_cosine_schedule_with_warmup (line 152) | def get_cosine_schedule_with_warmup(
  function get_cosine_with_hard_restarts_schedule_with_warmup (line 186) | def get_cosine_with_hard_restarts_schedule_with_warmup(
  function get_polynomial_decay_schedule_with_warmup (line 221) | def get_polynomial_decay_schedule_with_warmup(
  function get_scheduler (line 282) | def get_scheduler(

FILE: diffusers/pipelines/alt_diffusion/__init__.py
  class AltDiffusionPipelineOutput (line 13) | class AltDiffusionPipelineOutput(BaseOutput):

FILE: diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
  class TransformationModelOutput (line 11) | class TransformationModelOutput(ModelOutput):
  class RobertaSeriesConfig (line 39) | class RobertaSeriesConfig(XLMRobertaConfig):
    method __init__ (line 40) | def __init__(
  class RobertaSeriesModelWithTransformation (line 58) | class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    method __init__ (line 64) | def __init__(self, config):
    method forward (line 74) | def forward(

FILE: diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
  class AltDiffusionPipeline (line 55) | class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    method __init__ (line 92) | def __init__(
    method enable_vae_slicing (line 182) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 191) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 198) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 207) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 214) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 239) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 268) | def _execution_device(self):
    method _encode_prompt (line 285) | def _encode_prompt(
    method run_safety_checker (line 431) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 445) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 460) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 477) | def check_inputs(
    method prepare_latents (line 524) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 543) | def __call__(

FILE: diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
  function preprocess (line 71) | def preprocess(image):
  class AltDiffusionImg2ImgPipeline (line 93) | class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoa...
    method __init__ (line 130) | def __init__(
    method enable_sequential_cpu_offload (line 220) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 245) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 274) | def _execution_device(self):
    method _encode_prompt (line 291) | def _encode_prompt(
    method run_safety_checker (line 437) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 451) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 466) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 483) | def check_inputs(
    method get_timesteps (line 523) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 532) | def prepare_latents(self, image, timestep, batch_size, num_images_per_...
    method __call__ (line 586) | def __call__(

FILE: diffusers/pipelines/audio_diffusion/mel.py
  class Mel (line 37) | class Mel(ConfigMixin, SchedulerMixin):
    method __init__ (line 52) | def __init__(
    method set_resolution (line 73) | def set_resolution(self, x_res: int, y_res: int):
    method load_audio (line 85) | def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = N...
    method get_number_of_slices (line 101) | def get_number_of_slices(self) -> int:
    method get_audio_slice (line 109) | def get_audio_slice(self, slice: int = 0) -> np.ndarray:
    method get_sample_rate (line 120) | def get_sample_rate(self) -> int:
    method audio_slice_to_image (line 128) | def audio_slice_to_image(self, slice: int) -> Image.Image:
    method image_to_audio (line 145) | def image_to_audio(self, image: Image.Image) -> np.ndarray:

FILE: diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py
  class AudioDiffusionPipeline (line 30) | class AudioDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 44) | def __init__(
    method get_default_steps (line 54) | def get_default_steps(self) -> int:
    method __call__ (line 63) | def __call__(
    method encode (line 199) | def encode(self, images: List[Image.Image], steps: int = 50) -> np.nda...
    method slerp (line 236) | def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.T...

FILE: diffusers/pipelines/audioldm/pipeline_audioldm.py
  class AudioLDMPipeline (line 46) | class AudioLDMPipeline(DiffusionPipeline):
    method __init__ (line 72) | def __init__(
    method enable_vae_slicing (line 94) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 104) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 111) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 129) | def _execution_device(self):
    method _encode_prompt (line 146) | def _encode_prompt(
    method decode_latents (line 285) | def decode_latents(self, latents):
    method mel_spectrogram_to_waveform (line 290) | def mel_spectrogram_to_waveform(self, mel_spectrogram):
    method prepare_extra_step_kwargs (line 300) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 317) | def check_inputs(
    method prepare_latents (line 376) | def prepare_latents(self, batch_size, num_channels_latents, height, dt...
    method __call__ (line 400) | def __call__(

FILE: diffusers/pipelines/controlnet/multicontrolnet.py
  class MultiControlNetModel (line 10) | class MultiControlNetModel(ModelMixin):
    method __init__ (line 23) | def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[Con...
    method forward (line 27) | def forward(

FILE: diffusers/pipelines/controlnet/pipeline_controlnet.py
  class StableDiffusionControlNetPipeline (line 95) | class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInvers...
    method __init__ (line 131) | def __init__(
    method enable_vae_slicing (line 179) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 189) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 197) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 207) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 214) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 235) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 265) | def _execution_device(self):
    method _encode_prompt (line 283) | def _encode_prompt(
    method run_safety_checker (line 430) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 445) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 459) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 476) | def check_inputs(
    method check_image (line 592) | def check_image(self, image, prompt, prompt_embeds):
    method prepare_image (line 624) | def prepare_image(
    method prepare_latents (line 677) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method _default_height_width (line 694) | def _default_height_width(self, height, width, image):
    method save_pretrained (line 720) | def save_pretrained(
    method __call__ (line 733) | def __call__(

FILE: diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
  function prepare_image (line 97) | def prepare_image(image):
  class StableDiffusionControlNetImg2ImgPipeline (line 121) | class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, Textua...
    method __init__ (line 157) | def __init__(
    method enable_vae_slicing (line 205) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 215) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 223) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 233) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 240) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 261) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 291) | def _execution_device(self):
    method _encode_prompt (line 309) | def _encode_prompt(
    method run_safety_checker (line 456) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 471) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 485) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 502) | def check_inputs(
    method check_image (line 618) | def check_image(self, image, prompt, prompt_embeds):
    method prepare_control_image (line 651) | def prepare_control_image(
    method get_timesteps (line 704) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 714) | def prepare_latents(self, image, timestep, batch_size, num_images_per_...
    method _default_height_width (line 766) | def _default_height_width(self, height, width, image):
    method save_pretrained (line 792) | def save_pretrained(
    method __call__ (line 805) | def __call__(

FILE: diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
  function prepare_mask_and_masked_image (line 102) | def prepare_mask_and_masked_image(image, mask, height, width, return_ima...
  class StableDiffusionControlNetInpaintPipeline (line 219) | class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, Textua...
    method __init__ (line 255) | def __init__(
    method enable_vae_slicing (line 303) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 313) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 321) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 331) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 338) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 359) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 389) | def _execution_device(self):
    method _encode_prompt (line 407) | def _encode_prompt(
    method run_safety_checker (line 554) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 569) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 583) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 600) | def check_inputs(
    method check_image (line 716) | def check_image(self, image, prompt, prompt_embeds):
    method prepare_control_image (line 749) | def prepare_control_image(
    method prepare_latents (line 802) | def prepare_latents(
    method _default_height_width (line 856) | def _default_height_width(self, height, width, image):
    method prepare_mask_latents (line 882) | def prepare_mask_latents(
    method save_pretrained (line 934) | def save_pretrained(
    method __call__ (line 947) | def __call__(

FILE: diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
  class FlaxStableDiffusionControlNetPipeline (line 119) | class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    method __init__ (line 150) | def __init__(
    method prepare_text_inputs (line 189) | def prepare_text_inputs(self, prompt: Union[str, List[str]]):
    method prepare_image_inputs (line 203) | def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Im...
    method _get_has_nsfw_concepts (line 214) | def _get_has_nsfw_concepts(self, features, params):
    method _run_safety_checker (line 218) | def _run_safety_checker(self, images, safety_model_params, jit=False):
    method _generate (line 248) | def _generate(
    method __call__ (line 358) | def __call__(
  function _p_generate (line 493) | def _p_generate(
  function _p_get_has_nsfw_concepts (line 519) | def _p_get_has_nsfw_concepts(pipe, features, params):
  function unshard (line 523) | def unshard(x: jnp.ndarray):
  function preprocess (line 530) | def preprocess(image, dtype):

FILE: diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
  class DanceDiffusionPipeline (line 27) | class DanceDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 39) | def __init__(self, unet, scheduler):
    method __call__ (line 44) | def __call__(

FILE: diffusers/pipelines/ddim/pipeline_ddim.py
  class DDIMPipeline (line 24) | class DDIMPipeline(DiffusionPipeline):
    method __init__ (line 36) | def __init__(self, unet, scheduler):
    method __call__ (line 45) | def __call__(

FILE: diffusers/pipelines/ddpm/pipeline_ddpm.py
  class DDPMPipeline (line 24) | class DDPMPipeline(DiffusionPipeline):
    method __init__ (line 36) | def __init__(self, unet, scheduler):
    method __call__ (line 41) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/__init__.py
  class IFPipelineOutput (line 21) | class IFPipelineOutput(BaseOutput):

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if.py
  class IFPipeline (line 89) | class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
    method __init__ (line 107) | def __init__(
    method enable_sequential_cpu_offload (line 147) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 171) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 212) | def remove_all_hooks(self):
    method _execution_device (line 228) | def _execution_device(self):
    method encode_prompt (line 246) | def encode_prompt(
    method run_safety_checker (line 397) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 414) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 431) | def check_inputs(
    method prepare_intermediate_images (line 473) | def prepare_intermediate_images(self, batch_size, num_channels, height...
    method _text_preprocessing (line 487) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 511) | def _clean_caption(self, caption):
    method __call__ (line 627) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
  function resize (line 41) | def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
  class IFImg2ImgPipeline (line 113) | class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
    method __init__ (line 131) | def __init__(
    method enable_sequential_cpu_offload (line 172) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 197) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 239) | def remove_all_hooks(self):
    method _execution_device (line 255) | def _execution_device(self):
    method encode_prompt (line 274) | def encode_prompt(
    method run_safety_checker (line 426) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 443) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 460) | def check_inputs(
    method _text_preprocessing (line 534) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 559) | def _clean_caption(self, caption):
    method preprocess_image (line 673) | def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
    method get_timesteps (line 709) | def get_timesteps(self, num_inference_steps, strength):
    method prepare_intermediate_images (line 718) | def prepare_intermediate_images(
    method __call__ (line 742) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
  function resize (line 43) | def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
  class IFImg2ImgSuperResolutionPipeline (line 115) | class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
    method __init__ (line 134) | def __init__(
    method enable_sequential_cpu_offload (line 182) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 207) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 249) | def remove_all_hooks(self):
    method _text_preprocessing (line 264) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 289) | def _clean_caption(self, caption):
    method _execution_device (line 405) | def _execution_device(self):
    method encode_prompt (line 424) | def encode_prompt(
    method run_safety_checker (line 576) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 593) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 610) | def check_inputs(
    method preprocess_original_image (line 720) | def preprocess_original_image(self, image: PIL.Image.Image) -> torch.T...
    method preprocess_image (line 757) | def preprocess_image(self, image: PIL.Image.Image, num_images_per_prom...
    method get_timesteps (line 789) | def get_timesteps(self, num_inference_steps, strength):
    method prepare_intermediate_images (line 799) | def prepare_intermediate_images(
    method __call__ (line 823) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
  function resize (line 42) | def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
  class IFInpaintingPipeline (line 116) | class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
    method __init__ (line 134) | def __init__(
    method enable_sequential_cpu_offload (line 175) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 200) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 242) | def remove_all_hooks(self):
    method _execution_device (line 258) | def _execution_device(self):
    method encode_prompt (line 277) | def encode_prompt(
    method run_safety_checker (line 429) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 446) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 463) | def check_inputs(
    method _text_preprocessing (line 573) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 598) | def _clean_caption(self, caption):
    method preprocess_image (line 713) | def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
    method preprocess_mask_image (line 749) | def preprocess_mask_image(self, mask_image) -> torch.Tensor:
    method get_timesteps (line 799) | def get_timesteps(self, num_inference_steps, strength):
    method prepare_intermediate_images (line 808) | def prepare_intermediate_images(
    method __call__ (line 834) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
  function resize (line 43) | def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
  class IFInpaintingSuperResolutionPipeline (line 117) | class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
    method __init__ (line 136) | def __init__(
    method enable_sequential_cpu_offload (line 184) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 209) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 251) | def remove_all_hooks(self):
    method _text_preprocessing (line 266) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 291) | def _clean_caption(self, caption):
    method _execution_device (line 407) | def _execution_device(self):
    method encode_prompt (line 426) | def encode_prompt(
    method run_safety_checker (line 578) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 595) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 612) | def check_inputs(
    method preprocess_original_image (line 756) | def preprocess_original_image(self, image: PIL.Image.Image) -> torch.T...
    method preprocess_image (line 793) | def preprocess_image(self, image: PIL.Image.Image, num_images_per_prom...
    method preprocess_mask_image (line 825) | def preprocess_mask_image(self, mask_image) -> torch.Tensor:
    method get_timesteps (line 875) | def get_timesteps(self, num_inference_steps, strength):
    method prepare_intermediate_images (line 885) | def prepare_intermediate_images(
    method __call__ (line 911) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
  class IFSuperResolutionPipeline (line 73) | class IFSuperResolutionPipeline(DiffusionPipeline):
    method __init__ (line 92) | def __init__(
    method enable_sequential_cpu_offload (line 140) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 165) | def enable_model_cpu_offload(self, gpu_id=0):
    method remove_all_hooks (line 207) | def remove_all_hooks(self):
    method _text_preprocessing (line 222) | def _text_preprocessing(self, text, clean_caption=False):
    method _clean_caption (line 247) | def _clean_caption(self, caption):
    method _execution_device (line 363) | def _execution_device(self):
    method encode_prompt (line 382) | def encode_prompt(
    method run_safety_checker (line 534) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 551) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 568) | def check_inputs(
    method prepare_intermediate_images (line 648) | def prepare_intermediate_images(self, batch_size, num_channels, height...
    method preprocess_image (line 662) | def preprocess_image(self, image, num_images_per_prompt, device):
    method __call__ (line 695) | def __call__(

FILE: diffusers/pipelines/deepfloyd_if/safety_checker.py
  class IFSafetyChecker (line 12) | class IFSafetyChecker(PreTrainedModel):
    method __init__ (line 17) | def __init__(self, config: CLIPConfig):
    method forward (line 26) | def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):

FILE: diffusers/pipelines/deepfloyd_if/watermark.py
  class IFWatermarker (line 12) | class IFWatermarker(ModelMixin, ConfigMixin):
    method __init__ (line 13) | def __init__(self):
    method apply_watermark (line 19) | def apply_watermark(self, images: List[PIL.Image.Image], sample_size=N...

FILE: diffusers/pipelines/dit/pipeline_dit.py
  class DiTPipeline (line 31) | class DiTPipeline(DiffusionPipeline):
    method __init__ (line 45) | def __init__(
    method get_label_ids (line 63) | def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    method __call__ (line 87) | def __call__(

FILE: diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
  class LDMTextToImagePipeline (line 32) | class LDMTextToImagePipeline(DiffusionPipeline):
    method __init__ (line 51) | def __init__(
    method __call__ (line 64) | def __call__(
  class LDMBertConfig (line 220) | class LDMBertConfig(PretrainedConfig):
    method __init__ (line 225) | def __init__(
  function _expand_mask (line 267) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
  class LDMBertAttention (line 282) | class LDMBertAttention(nn.Module):
    method __init__ (line 285) | def __init__(
    method _shape (line 309) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    method forward (line 312) | def forward(
  class LDMBertEncoderLayer (line 426) | class LDMBertEncoderLayer(nn.Module):
    method __init__ (line 427) | def __init__(self, config: LDMBertConfig):
    method forward (line 444) | def forward(
  class LDMBertPreTrainedModel (line 496) | class LDMBertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 502) | def _init_weights(self, module):
    method _set_gradient_checkpointing (line 513) | def _set_gradient_checkpointing(self, module, value=False):
    method dummy_inputs (line 518) | def dummy_inputs(self):
  class LDMBertEncoder (line 528) | class LDMBertEncoder(LDMBertPreTrainedModel):
    method __init__ (line 538) | def __init__(self, config: LDMBertConfig):
    method get_input_embeddings (line 556) | def get_input_embeddings(self):
    method set_input_embeddings (line 559) | def set_input_embeddings(self, value):
    method forward (line 562) | def forward(
  class LDMBertModel (line 695) | class LDMBertModel(LDMBertPreTrainedModel):
    method __init__ (line 698) | def __init__(self, config: LDMBertConfig):
    method forward (line 703) | def forward(

FILE: diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
  function preprocess (line 22) | def preprocess(image):
  class LDMSuperResolutionPipeline (line 32) | class LDMSuperResolutionPipeline(DiffusionPipeline):
    method __init__ (line 49) | def __init__(
    method __call__ (line 66) | def __call__(

FILE: diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
  class LDMPipeline (line 26) | class LDMPipeline(DiffusionPipeline):
    method __init__ (line 39) | def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMS...
    method __call__ (line 44) | def __call__(

FILE: diffusers/pipelines/onnx_utils.py
  class OnnxRuntimeModel (line 51) | class OnnxRuntimeModel:
    method __init__ (line 52) | def __init__(self, model=None, **kwargs):
    method __call__ (line 58) | def __call__(self, **kwargs):
    method load_model (line 63) | def load_model(path: Union[str, Path], provider=None, sess_options=None):
    method _save_pretrained (line 79) | def _save_pretrained(self, save_directory: Union[str, Path], file_name...
    method save_pretrained (line 110) | def save_pretrained(
    method _from_pretrained (line 133) | def _from_pretrained(
    method from_pretrained (line 193) | def from_pretrained(

FILE: diffusers/pipelines/paint_by_example/image_encoder.py
  class PaintByExampleImageEncoder (line 25) | class PaintByExampleImageEncoder(CLIPPreTrainedModel):
    method __init__ (line 26) | def __init__(self, config, proj_size=768):
    method forward (line 38) | def forward(self, pixel_values, return_uncond_vector=False):
  class PaintByExampleMapper (line 50) | class PaintByExampleMapper(nn.Module):
    method __init__ (line 51) | def __init__(self, config):
    method forward (line 63) | def forward(self, hidden_states):

FILE: diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
  function prepare_mask_and_masked_image (line 39) | def prepare_mask_and_masked_image(image, mask):
  class PaintByExamplePipeline (line 139) | class PaintByExamplePipeline(DiffusionPipeline):
    method __init__ (line 168) | def __init__(
    method enable_sequential_cpu_offload (line 192) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 213) | def _execution_device(self):
    method run_safety_checker (line 231) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 246) | def prepare_extra_step_kwargs(self, generator, eta):
    method decode_latents (line 264) | def decode_latents(self, latents):
    method check_inputs (line 278) | def check_inputs(self, image, height, width, callback_steps):
    method prepare_latents (line 301) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method prepare_mask_latents (line 319) | def prepare_mask_latents(
    method _encode_image (line 370) | def _encode_image(self, image, device, num_images_per_prompt, do_class...
    method __call__ (line 396) | def __call__(

FILE: diffusers/pipelines/pipeline_flax_utils.py
  function import_flax_or_no_model (line 66) | def import_flax_or_no_model(module, class_name):
  class FlaxImagePipelineOutput (line 80) | class FlaxImagePipelineOutput(BaseOutput):
  class FlaxDiffusionPipeline (line 93) | class FlaxDiffusionPipeline(ConfigMixin):
    method register_modules (line 110) | def register_modules(self, **kwargs):
    method save_pretrained (line 143) | def save_pretrained(self, save_directory: Union[str, os.PathLike], par...
    method from_pretrained (line 194) | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union...
    method _get_signature_keys (line 496) | def _get_signature_keys(obj):
    method components (line 504) | def components(self) -> Dict[str, Any]:
    method numpy_to_pil (line 541) | def numpy_to_pil(images):
    method progress_bar (line 557) | def progress_bar(self, iterable):
    method set_progress_bar_config (line 567) | def set_progress_bar_config(self, **kwargs):

FILE: diffusers/pipelines/pipeline_utils.py
  class ImagePipelineOutput (line 111) | class ImagePipelineOutput(BaseOutput):
  class AudioPipelineOutput (line 125) | class AudioPipelineOutput(BaseOutput):
  function is_safetensors_compatible (line 138) | def is_safetensors_compatible(filenames, variant=None, passed_components...
  function variant_compatible_siblings (line 187) | def variant_compatible_siblings(filenames, variant=None) -> Union[List[o...
  function warn_deprecated_model_variant (line 254) | def warn_deprecated_model_variant(pretrained_model_name_or_path, use_aut...
  function maybe_raise_or_warn (line 276) | def maybe_raise_or_warn(
  function get_class_obj_and_candidates (line 308) | def get_class_obj_and_candidates(library_name, class_name, importable_cl...
  function _get_pipeline_class (line 325) | def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_d...
  function load_sub_model (line 346) | def load_sub_model(
  class DiffusionPipeline (line 452) | class DiffusionPipeline(ConfigMixin):
    method register_modules (line 472) | def register_modules(self, **kwargs):
    method __setattr__ (line 509) | def __setattr__(self, name: str, value: Any):
    method save_pretrained (line 524) | def save_pretrained(
    method to (line 611) | def to(
    method device (line 687) | def device(self) -> torch.device:
    method from_pretrained (line 702) | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union...
    method download (line 1084) | def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.Pa...
    method _get_signature_keys (line 1348) | def _get_signature_keys(obj):
    method components (line 1356) | def components(self) -> Dict[str, Any]:
    method numpy_to_pil (line 1393) | def numpy_to_pil(images):
    method progress_bar (line 1399) | def progress_bar(self, iterable=None, total=None):
    method set_progress_bar_config (line 1414) | def set_progress_bar_config(self, **kwargs):
    method enable_xformers_memory_efficient_attention (line 1417) | def enable_xformers_memory_efficient_attention(self, attention_op: Opt...
    method disable_xformers_memory_efficient_attention (line 1449) | def disable_xformers_memory_efficient_attention(self):
    method set_use_memory_efficient_attention_xformers (line 1455) | def set_use_memory_efficient_attention_xformers(
    method enable_attention_slicing (line 1475) | def enable_attention_slicing(self, slice_size: Optional[Union[str, int...
    method disable_attention_slicing (line 1491) | def disable_attention_slicing(self):
    method set_attention_slice (line 1499) | def set_attention_slice(self, slice_size: Optional[int]):

FILE: diffusers/pipelines/pndm/pipeline_pndm.py
  class PNDMPipeline (line 26) | class PNDMPipeline(DiffusionPipeline):
    method __init__ (line 40) | def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
    method __call__ (line 48) | def __call__(

FILE: diffusers/pipelines/repaint/pipeline_repaint.py
  function _preprocess_image (line 32) | def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
  function _preprocess_mask (line 53) | def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
  class RePaintPipeline (line 73) | class RePaintPipeline(DiffusionPipeline):
    method __init__ (line 77) | def __init__(self, unet, scheduler):
    method __call__ (line 82) | def __call__(

FILE: diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
  class ScoreSdeVePipeline (line 25) | class ScoreSdeVePipeline(DiffusionPipeline):
    method __init__ (line 36) | def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
    method __call__ (line 41) | def __call__(

FILE: diffusers/pipelines/semantic_stable_diffusion/__init__.py
  class SemanticStableDiffusionPipelineOutput (line 13) | class SemanticStableDiffusionPipelineOutput(BaseOutput):

FILE: diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
  class SemanticStableDiffusionPipeline (line 63) | class SemanticStableDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 95) | def __init__(
    method run_safety_checker (line 138) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 153) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 167) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 185) | def check_inputs(
    method prepare_latents (line 233) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 251) | def __call__(

FILE: diffusers/pipelines/spectrogram_diffusion/continous_encoder.py
  class SpectrogramContEncoder (line 29) | class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    method __init__ (line 31) | def __init__(
    method forward (line 71) | def forward(self, encoder_inputs, encoder_inputs_mask):

FILE: diffusers/pipelines/spectrogram_diffusion/midi_utils.py
  class NoteRepresentationConfig (line 84) | class NoteRepresentationConfig:
  class NoteEventData (line 92) | class NoteEventData:
  class NoteEncodingState (line 101) | class NoteEncodingState:
  class EventRange (line 109) | class EventRange:
  class Event (line 116) | class Event:
  class Tokenizer (line 121) | class Tokenizer:
    method __init__ (line 122) | def __init__(self, regular_ids: int):
    method encode (line 127) | def encode(self, token_ids):
  class Codec (line 145) | class Codec:
    method __init__ (line 156) | def __init__(self, max_shift_steps: int, steps_per_second: float, even...
    method num_classes (line 172) | def num_classes(self) -> int:
    method is_shift_event_index (line 178) | def is_shift_event_index(self, index: int) -> bool:
    method max_shift_steps (line 182) | def max_shift_steps(self) -> int:
    method encode_event (line 185) | def encode_event(self, event: Event) -> int:
    method event_type_range (line 200) | def event_type_range(self, event_type: str) -> Tuple[int, int]:
    method decode_event_index (line 210) | def decode_event_index(self, index: int) -> Event:
  class ProgramGranularity (line 222) | class ProgramGranularity:
  function drop_programs (line 228) | def drop_programs(tokens, codec: Codec):
  function programs_to_midi_classes (line 234) | def programs_to_midi_classes(tokens, codec):
  function frame (line 254) | def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, ...
  function program_to_slakh_program (line 272) | def program_to_slakh_program(program):
  function audio_to_frames (line 279) | def audio_to_frames(
  function note_sequence_to_onsets_and_offsets_and_programs (line 302) | def note_sequence_to_onsets_and_offsets_and_programs(
  function num_velocity_bins_from_codec (line 332) | def num_velocity_bins_from_codec(codec: Codec):
  function segment (line 339) | def segment(a, n):
  function velocity_to_bin (line 343) | def velocity_to_bin(velocity, num_velocity_bins):
  function note_event_data_to_events (line 350) | def note_event_data_to_events(
  function note_encoding_state_to_events (line 382) | def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[...
  function encode_and_index_events (line 392) | def encode_and_index_events(
  function extract_sequence_with_indices (line 498) | def extract_sequence_with_indices(features, state_events_end_token=None,...
  function map_midi_programs (line 524) | def map_midi_programs(
  function run_length_encode_shifts_fn (line 534) | def run_length_encode_shifts_fn(
  function note_representation_processor_chain (line 604) | def note_representation_processor_chain(features, codec: Codec, note_rep...
  class MidiProcessor (line 619) | class MidiProcessor:
    method __init__ (line 620) | def __init__(self):
    method __call__ (line 635) | def __call__(self, midi: Union[bytes, os.PathLike, str]):

FILE: diffusers/pipelines/spectrogram_diffusion/notes_encoder.py
  class SpectrogramNotesEncoder (line 25) | class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    method __init__ (line 27) | def __init__(
    method forward (line 69) | def forward(self, encoder_input_tokens, encoder_inputs_mask):

FILE: diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
  class SpectrogramDiffusionPipeline (line 40) | class SpectrogramDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 43) | def __init__(
    method scale_features (line 66) | def scale_features(self, features, output_range=(-1.0, 1.0), clip=False):
    method scale_to_features (line 76) | def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=Fal...
    method encode (line 85) | def encode(self, input_tokens, continuous_inputs, continuous_mask):
    method decode (line 97) | def decode(self, encodings_and_masks, input_tokens, noise_time):
    method __call__ (line 113) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/__init__.py
  class StableDiffusionPipelineOutput (line 22) | class StableDiffusionPipelineOutput(BaseOutput):
  class FlaxStableDiffusionPipelineOutput (line 115) | class FlaxStableDiffusionPipelineOutput(BaseOutput):

FILE: diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
  function shave_segments (line 63) | def shave_segments(path, n_shave_prefix_segments=1):
  function renew_resnet_paths (line 73) | def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
  function renew_vae_resnet_paths (line 95) | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
  function renew_attention_paths (line 111) | def renew_attention_paths(old_list, n_shave_prefix_segments=0):
  function renew_vae_attention_paths (line 132) | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
  function assign_to_checkpoint (line 162) | def assign_to_checkpoint(
  function conv_attn_to_linear (line 217) | def conv_attn_to_linear(checkpoint):
  function create_unet_diffusers_config (line 229) | def create_unet_diffusers_config(original_config, image_size: int, contr...
  function create_vae_diffusers_config (line 298) | def create_vae_diffusers_config(original_config, image_size: int):
  function create_diffusers_schedular (line 322) | def create_diffusers_schedular(original_config):
  function create_ldm_bert_config (line 332) | def create_ldm_bert_config(original_config):
  function convert_ldm_unet_checkpoint (line 342) | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_e...
  function convert_ldm_vae_checkpoint (line 573) | def convert_ldm_vae_checkpoint(checkpoint, config):
  function convert_ldm_bert_checkpoint (line 680) | def convert_ldm_bert_checkpoint(checkpoint, config):
  function convert_ldm_clip_checkpoint (line 730) | def convert_ldm_clip_checkpoint(checkpoint):
  function convert_paint_by_example_checkpoint (line 770) | def convert_paint_by_example_checkpoint(checkpoint):
  function convert_open_clip_checkpoint (line 837) | def convert_open_clip_checkpoint(checkpoint):
  function stable_unclip_image_encoder (line 880) | def stable_unclip_image_encoder(original_config):
  function stable_unclip_image_noising_components (line 913) | def stable_unclip_image_noising_components(
  function convert_controlnet_checkpoint (line 958) | def convert_controlnet_checkpoint(
  function download_from_original_stable_diffusion_ckpt (line 977) | def download_from_original_stable_diffusion_ckpt(
  function download_controlnet_from_original_ckpt (line 1335) | def download_controlnet_from_original_ckpt(

FILE: diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
  function preprocess (line 42) | def preprocess(image):
  function posterior_sample (line 63) | def posterior_sample(scheduler, latents, timestep, clean_latents, genera...
  function compute_noise (line 90) | def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred...
  class CycleDiffusionPipeline (line 124) | class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMi...
    method __init__ (line 153) | def __init__(
    method enable_sequential_cpu_offload (line 230) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 256) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 286) | def _execution_device(self):
    method _encode_prompt (line 304) | def _encode_prompt(
    method check_inputs (line 451) | def check_inputs(
    method prepare_extra_step_kwargs (line 492) | def prepare_extra_step_kwargs(self, generator, eta):
    method run_safety_checker (line 510) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 525) | def decode_latents(self, latents):
    method get_timesteps (line 539) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 548) | def prepare_latents(self, image, timestep, batch_size, num_images_per_...
    method __call__ (line 598) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
  class FlaxStableDiffusionPipeline (line 81) | class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    method __init__ (line 110) | def __init__(
    method prepare_inputs (line 168) | def prepare_inputs(self, prompt: Union[str, List[str]]):
    method _get_has_nsfw_concepts (line 181) | def _get_has_nsfw_concepts(self, features, params):
    method _run_safety_checker (line 185) | def _run_safety_checker(self, images, safety_model_params, jit=False):
    method _generate (line 215) | def _generate(
    method __call__ (line 312) | def __call__(
  function _p_generate (line 436) | def _p_generate(
  function _p_get_has_nsfw_concepts (line 462) | def _p_get_has_nsfw_concepts(pipe, features, params):
  function unshard (line 466) | def unshard(x: jnp.ndarray):

FILE: diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
  class FlaxStableDiffusionImg2ImgPipeline (line 105) | class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
    method __init__ (line 134) | def __init__(
    method prepare_inputs (line 171) | def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[I...
    method _get_has_nsfw_concepts (line 192) | def _get_has_nsfw_concepts(self, features, params):
    method _run_safety_checker (line 196) | def _run_safety_checker(self, images, safety_model_params, jit=False):
    method get_timestep_start (line 226) | def get_timestep_start(self, num_inference_steps, strength):
    method _generate (line 234) | def _generate(
    method __call__ (line 339) | def __call__(
  function _p_generate (line 480) | def _p_generate(
  function _p_get_has_nsfw_concepts (line 510) | def _p_get_has_nsfw_concepts(pipe, features, params):
  function unshard (line 514) | def unshard(x: jnp.ndarray):
  function preprocess (line 521) | def preprocess(image, dtype):

FILE: diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
  class FlaxStableDiffusionInpaintPipeline (line 102) | class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
    method __init__ (line 131) | def __init__(
    method prepare_inputs (line 189) | def prepare_inputs(
    method _get_has_nsfw_concepts (line 228) | def _get_has_nsfw_concepts(self, features, params):
    method _run_safety_checker (line 232) | def _run_safety_checker(self, images, safety_model_params, jit=False):
    method _generate (line 262) | def _generate(
    method __call__ (line 390) | def __call__(
  function _p_generate (line 523) | def _p_generate(
  function _p_get_has_nsfw_concepts (line 553) | def _p_get_has_nsfw_concepts(pipe, features, params):
  function unshard (line 557) | def unshard(x: jnp.ndarray):
  function preprocess_image (line 564) | def preprocess_image(image, dtype):
  function preprocess_mask (line 573) | def preprocess_mask(mask, dtype):

FILE: diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
  class OnnxStableDiffusionPipeline (line 33) | class OnnxStableDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 45) | def __init__(
    method _encode_prompt (line 114) | def _encode_prompt(
    method check_inputs (line 217) | def check_inputs(
    method __call__ (line 264) | def __call__(
  class StableDiffusionOnnxPipeline (line 462) | class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
    method __init__ (line 463) | def __init__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
  function preprocess (line 35) | def preprocess(image):
  class OnnxStableDiffusionImg2ImgPipeline (line 56) | class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
    method __init__ (line 94) | def __init__(
    method _encode_prompt (line 164) | def _encode_prompt(
    method check_inputs (line 267) | def check_inputs(
    method __call__ (line 309) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
  function prepare_mask_and_masked_image (line 38) | def prepare_mask_and_masked_image(image, mask, latents_shape):
  class OnnxStableDiffusionInpaintPipeline (line 56) | class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
    method __init__ (line 94) | def __init__(
    method _encode_prompt (line 165) | def _encode_prompt(
    method check_inputs (line 269) | def check_inputs(
    method __call__ (line 317) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
  function preprocess (line 20) | def preprocess(image):
  function preprocess_mask (line 29) | def preprocess_mask(mask, scale_factor=8):
  class OnnxStableDiffusionInpaintPipelineLegacy (line 41) | class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
    method __init__ (line 80) | def __init__(
    method _encode_prompt (line 150) | def _encode_prompt(
    method check_inputs (line 253) | def check_inputs(
    method __call__ (line 295) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
  function preprocess (line 26) | def preprocess(image):
  class OnnxStableDiffusionUpscalePipeline (line 48) | class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    method __init__ (line 49) | def __init__(
    method __call__ (line 72) | def __call__(
    method decode_latents (line 281) | def decode_latents(self, latents):
    method _encode_prompt (line 288) | def _encode_prompt(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
  class StableDiffusionPipeline (line 58) | class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderM...
    method __init__ (line 95) | def __init__(
    method enable_vae_slicing (line 185) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 194) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 201) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 210) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 217) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 242) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 271) | def _execution_device(self):
    method _encode_prompt (line 288) | def _encode_prompt(
    method run_safety_checker (line 434) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 448) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 461) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 478) | def check_inputs(
    method prepare_latents (line 525) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 544) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
  class AttentionStore (line 74) | class AttentionStore:
    method get_empty_store (line 76) | def get_empty_store():
    method __call__ (line 79) | def __call__(self, attn, is_cross: bool, place_in_unet: str):
    method between_steps (line 89) | def between_steps(self):
    method get_average_attention (line 93) | def get_average_attention(self):
    method aggregate_attention (line 97) | def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
    method reset (line 109) | def reset(self):
    method __init__ (line 114) | def __init__(self, attn_res):
  class AttendExciteAttnProcessor (line 127) | class AttendExciteAttnProcessor:
    method __init__ (line 128) | def __init__(self, attnstore, place_in_unet):
    method __call__ (line 133) | def __call__(self, attn: Attention, hidden_states, encoder_hidden_stat...
  class StableDiffusionAttendAndExcitePipeline (line 165) | class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualI...
    method __init__ (line 194) | def __init__(
    method enable_vae_slicing (line 237) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 247) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 255) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 282) | def _execution_device(self):
    method _encode_prompt (line 300) | def _encode_prompt(
    method run_safety_checker (line 447) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 462) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 476) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 493) | def check_inputs(
    method prepare_latents (line 567) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method _compute_max_attention_per_index (line 585) | def _compute_max_attention_per_index(
    method _aggregate_and_get_max_attention_per_token (line 607) | def _aggregate_and_get_max_attention_per_token(
    method _compute_loss (line 622) | def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torc...
    method _update_latent (line 629) | def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_siz...
    method _perform_iterative_refinement_step (line 635) | def _perform_iterative_refinement_step(
    method register_attention_control (line 689) | def register_attention_control(self):
    method get_indices (line 708) | def get_indices(self, prompt: str) -> Dict[str, int]:
    method __call__ (line 716) | def __call__(
  class GaussianSmoothing (line 1007) | class GaussianSmoothing(torch.nn.Module):
    method __init__ (line 1020) | def __init__(
    method forward (line 1061) | def forward(self, input):

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
  function preprocess (line 39) | def preprocess(image):
  class StableDiffusionDepth2ImgPipeline (line 60) | class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversi...
    method __init__ (line 90) | def __init__(
    method enable_sequential_cpu_offload (line 135) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 154) | def _execution_device(self):
    method _encode_prompt (line 172) | def _encode_prompt(
    method run_safety_checker (line 319) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 334) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 348) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 366) | def check_inputs(
    method get_timesteps (line 407) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 417) | def prepare_latents(self, image, timestep, batch_size, num_images_per_...
    method prepare_depth_map (line 469) | def prepare_depth_map(self, image, depth_map, batch_size, do_classifie...
    method __call__ (line 512) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
  class DiffEditInversionPipelineOutput (line 50) | class DiffEditInversionPipelineOutput(BaseOutput):
  function auto_corr_loss (line 140) | def auto_corr_loss(hidden_states, generator=None):
  function kl_divergence (line 156) | def kl_divergence(hidden_states):
  function preprocess (line 161) | def preprocess(image):
  function preprocess_mask (line 182) | def preprocess_mask(mask, batch_size: int = 1):
  class StableDiffusionDiffEditPipeline (line 235) | class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversio...
    method __init__ (line 272) | def __init__(
    method enable_vae_slicing (line 366) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 376) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 384) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 394) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 402) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 428) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 458) | def _execution_device(self):
    method _encode_prompt (line 476) | def _encode_prompt(
    method run_safety_checker (line 623) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 638) | def prepare_extra_step_kwargs(self, generator, eta):
    method decode_latents (line 656) | def decode_latents(self, latents):
    method check_inputs (line 669) | def check_inputs(
    method check_source_inputs (line 717) | def check_source_inputs(
    method get_timesteps (line 753) | def get_timesteps(self, num_inference_steps, strength, device):
    method get_inverse_timesteps (line 762) | def get_inverse_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 776) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method prepare_image_latents (line 794) | def prepare_image_latents(self, image, batch_size, dtype, device, gene...
    method get_epsilon (line 837) | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor...
    method generate_mask (line 855) | def generate_mask(
    method invert (line 1077) | def invert(
    method __call__ (line 1315) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
  class StableDiffusionImageVariationPipeline (line 37) | class StableDiffusionImageVariationPipeline(DiffusionPipeline):
    method __init__ (line 65) | def __init__(
    method enable_sequential_cpu_offload (line 126) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 145) | def _execution_device(self):
    method _encode_image (line 162) | def _encode_image(self, image, device, num_images_per_prompt, do_class...
    method run_safety_checker (line 188) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 203) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 217) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 234) | def check_inputs(self, image, height, width, callback_steps):
    method prepare_latents (line 257) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 275) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
  function preprocess (line 75) | def preprocess(image):
  class StableDiffusionImg2ImgPipeline (line 96) | class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversion...
    method __init__ (line 133) | def __init__(
    method enable_sequential_cpu_offload (line 224) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 250) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 280) | def _execution_device(self):
    method _encode_prompt (line 298) | def _encode_prompt(
    method run_safety_checker (line 444) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 458) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 472) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 489) | def check_inputs(
    method get_timesteps (line 529) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 538) | def prepare_latents(self, image, timestep, batch_size, num_images_per_...
    method __call__ (line 592) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
  function prepare_mask_and_masked_image (line 39) | def prepare_mask_and_masked_image(image, mask, height, width, return_ima...
  class StableDiffusionInpaintPipeline (line 156) | class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversion...
    method __init__ (line 192) | def __init__(
    method enable_sequential_cpu_offload (line 292) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 318) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 348) | def _execution_device(self):
    method _encode_prompt (line 366) | def _encode_prompt(
    method run_safety_checker (line 513) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 528) | def prepare_extra_step_kwargs(self, generator, eta):
    method decode_latents (line 546) | def decode_latents(self, latents):
    method check_inputs (line 559) | def check_inputs(
    method prepare_latents (line 610) | def prepare_latents(
    method prepare_mask_latents (line 664) | def prepare_mask_latents(
    method get_timesteps (line 716) | def get_timesteps(self, num_inference_steps, strength, device):
    method __call__ (line 726) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
  function preprocess_image (line 46) | def preprocess_image(image, batch_size):
  function preprocess_mask (line 56) | def preprocess_mask(mask, batch_size, scale_factor=8):
  class StableDiffusionInpaintPipelineLegacy (line 87) | class StableDiffusionInpaintPipelineLegacy(
    method __init__ (line 127) | def __init__(
    method enable_sequential_cpu_offload (line 218) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 244) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 274) | def _execution_device(self):
    method _encode_prompt (line 292) | def _encode_prompt(
    method run_safety_checker (line 439) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 454) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 468) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 486) | def check_inputs(
    method get_timesteps (line 527) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 536) | def prepare_latents(self, image, timestep, num_images_per_prompt, dtyp...
    method __call__ (line 553) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
  function preprocess (line 45) | def preprocess(image):
  class StableDiffusionInstructPix2PixPipeline (line 66) | class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualI...
    method __init__ (line 102) | def __init__(
    method __call__ (line 145) | def __call__(
    method enable_sequential_cpu_offload (line 416) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 442) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 472) | def _execution_device(self):
    method _encode_prompt (line 489) | def _encode_prompt(
    method run_safety_checker (line 637) | def run_safety_checker(self, image, device, dtype):
    method prepare_extra_step_kwargs (line 652) | def prepare_extra_step_kwargs(self, generator, eta):
    method decode_latents (line 670) | def decode_latents(self, latents):
    method check_inputs (line 683) | def check_inputs(
    method prepare_latents (line 721) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method prepare_image_latents (line 738) | def prepare_image_latents(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
  class ModelWrapper (line 34) | class ModelWrapper:
    method __init__ (line 35) | def __init__(self, model, alphas_cumprod):
    method apply_model (line 39) | def apply_model(self, *args, **kwargs):
  class StableDiffusionKDiffusionPipeline (line 48) | class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInvers...
    method __init__ (line 83) | def __init__(
    method set_scheduler (line 124) | def set_scheduler(self, scheduler_type: str):
    method enable_sequential_cpu_offload (line 130) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 156) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 186) | def _execution_device(self):
    method _encode_prompt (line 204) | def _encode_prompt(
    method run_safety_checker (line 351) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 366) | def decode_latents(self, latents):
    method check_inputs (line 380) | def check_inputs(
    method prepare_latents (line 427) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 440) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
  function preprocess (line 35) | def preprocess(image):
  class StableDiffusionLatentUpscalePipeline (line 56) | class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
    method __init__ (line 79) | def __init__(
    method enable_sequential_cpu_offload (line 99) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 118) | def _execution_device(self):
    method _encode_prompt (line 135) | def _encode_prompt(self, prompt, device, do_classifier_free_guidance, ...
    method decode_latents (line 226) | def decode_latents(self, latents):
    method check_inputs (line 239) | def check_inputs(self, prompt, image, callback_steps):
    method prepare_latents (line 277) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 291) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
  class StableDiffusionModelEditingPipeline (line 58) | class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInve...
    method __init__ (line 90) | def __init__(
    method enable_vae_slicing (line 168) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 178) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 186) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 213) | def _execution_device(self):
    method _encode_prompt (line 231) | def _encode_prompt(
    method run_safety_checker (line 378) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 393) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 407) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 425) | def check_inputs(
    method prepare_latents (line 473) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method edit_model (line 491) | def edit_model(
    method __call__ (line 609) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
  class StableDiffusionPanoramaPipeline (line 53) | class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversio...
    method __init__ (line 86) | def __init__(
    method enable_vae_slicing (line 132) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 142) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 150) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 177) | def _execution_device(self):
    method _encode_prompt (line 195) | def _encode_prompt(
    method run_safety_checker (line 342) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 357) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 371) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 389) | def check_inputs(
    method prepare_latents (line 437) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method get_views (line 454) | def get_views(self, panorama_height, panorama_width, window_size=64, s...
    method __call__ (line 472) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
  class Pix2PixInversionPipelineOutput (line 57) | class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderM...
  function preprocess (line 179) | def preprocess(image):
  function prepare_unet (line 200) | def prepare_unet(unet: UNet2DConditionModel):
  class Pix2PixZeroL2Loss (line 217) | class Pix2PixZeroL2Loss:
    method __init__ (line 218) | def __init__(self):
    method compute_loss (line 221) | def compute_loss(self, predictions, targets):
  class Pix2PixZeroAttnProcessor (line 225) | class Pix2PixZeroAttnProcessor:
    method __init__ (line 229) | def __init__(self, is_pix2pix_zero=False):
    method __call__ (line 234) | def __call__(
  class StableDiffusionPix2PixZeroPipeline (line 280) | class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
    method __init__ (line 318) | def __init__(
    method enable_sequential_cpu_offload (line 367) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 392) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 418) | def _execution_device(self):
    method _encode_prompt (line 436) | def _encode_prompt(
    method run_safety_checker (line 583) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 598) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 612) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 629) | def check_inputs(
    method prepare_latents (line 661) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method generate_caption (line 679) | def generate_caption(self, images):
    method construct_direction (line 698) | def construct_direction(self, embs_source: torch.Tensor, embs_target: ...
    method get_embeds (line 703) | def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch...
    method prepare_image_latents (line 722) | def prepare_image_latents(self, image, batch_size, dtype, device, gene...
    method get_epsilon (line 765) | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor...
    method auto_corr_loss (line 782) | def auto_corr_loss(self, hidden_states, generator=None):
    method kl_divergence (line 797) | def kl_divergence(self, hidden_states):
    method __call__ (line 804) | def __call__(
    method invert (line 1085) | def invert(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
  class CrossAttnStoreProcessor (line 53) | class CrossAttnStoreProcessor:
    method __init__ (line 54) | def __init__(self):
    method __call__ (line 57) | def __call__(
  class StableDiffusionSAGPipeline (line 93) | class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoad...
    method __init__ (line 122) | def __init__(
    method enable_vae_slicing (line 149) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 159) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 167) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 194) | def _execution_device(self):
    method _encode_prompt (line 212) | def _encode_prompt(
    method run_safety_checker (line 359) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 374) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 388) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 406) | def check_inputs(
    method prepare_latents (line 454) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 473) | def __call__(
    method sag_masking (line 716) | def sag_masking(self, original_latents, attn_map, map_size, t, eps):
    method pred_x0 (line 746) | def pred_x0(self, sample, model_output, timestep):
    method pred_epsilon (line 766) | def pred_epsilon(self, sample, model_output, timestep):
  function gaussian_blur_2d (line 786) | def gaussian_blur_2d(img, kernel_size, sigma):

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
  function preprocess (line 36) | def preprocess(image):
  class StableDiffusionUpscalePipeline (line 57) | class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversion...
    method __init__ (line 84) | def __init__(
    method enable_sequential_cpu_offload (line 130) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 147) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 175) | def _execution_device(self):
    method run_safety_checker (line 193) | def run_safety_checker(self, image, device, dtype):
    method _encode_prompt (line 210) | def _encode_prompt(
    method prepare_extra_step_kwargs (line 357) | def prepare_extra_step_kwargs(self, generator, eta):
    method decode_latents (line 375) | def decode_latents(self, latents):
    method check_inputs (line 388) | def check_inputs(
    method prepare_latents (line 469) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 483) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
  class StableUnCLIPPipeline (line 53) | class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    method __init__ (line 106) | def __init__(
    method enable_vae_slicing (line 144) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 154) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 161) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 185) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 212) | def _execution_device(self):
    method _encode_prior_prompt (line 230) | def _encode_prior_prompt(
    method _encode_prompt (line 332) | def _encode_prompt(
    method decode_latents (line 479) | def decode_latents(self, latents):
    method prepare_prior_extra_step_kwargs (line 493) | def prepare_prior_extra_step_kwargs(self, generator, eta):
    method prepare_extra_step_kwargs (line 511) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 528) | def check_inputs(
    method prepare_latents (line 589) | def prepare_latents(self, shape, dtype, device, generator, latents, sc...
    method noise_image_embeddings (line 600) | def noise_image_embeddings(
    method __call__ (line 648) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
  class StableUnCLIPImg2ImgPipeline (line 66) | class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoa...
    method __init__ (line 112) | def __init__(
    method enable_vae_slicing (line 146) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 156) | def disable_vae_slicing(self):
    method enable_sequential_cpu_offload (line 163) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 187) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 214) | def _execution_device(self):
    method _encode_prompt (line 232) | def _encode_prompt(
    method _encode_image (line 378) | def _encode_image(
    method decode_latents (line 434) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 448) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 465) | def check_inputs(
    method prepare_latents (line 549) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method noise_image_embeddings (line 567) | def noise_image_embeddings(
    method __call__ (line 615) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/safety_checker.py
  function cosine_distance (line 26) | def cosine_distance(image_embeds, text_embeds):
  class StableDiffusionSafetyChecker (line 32) | class StableDiffusionSafetyChecker(PreTrainedModel):
    method __init__ (line 37) | def __init__(self, config: CLIPConfig):
    method forward (line 50) | def forward(self, clip_input, images):
    method forward_onnx (line 102) | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.Fl...

FILE: diffusers/pipelines/stable_diffusion/safety_checker_flax.py
  function jax_cosine_distance (line 25) | def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
  class FlaxStableDiffusionSafetyCheckerModule (line 31) | class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
    method setup (line 35) | def setup(self):
    method __call__ (line 47) | def __call__(self, clip_input):
  class FlaxStableDiffusionSafetyChecker (line 71) | class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
    method __init__ (line 76) | def __init__(
    method init_weights (line 90) | def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, p...
    method __call__ (line 101) | def __call__(

FILE: diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
  class StableUnCLIPImageNormalizer (line 24) | class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
    method __init__ (line 33) | def __init__(
    method to (line 42) | def to(
    method scale (line 51) | def scale(self, embeds):
    method unscale (line 55) | def unscale(self, embeds):

FILE: diffusers/pipelines/stable_diffusion_safe/__init__.py
  class SafetyConfig (line 13) | class SafetyConfig(object):
  class StableDiffusionSafePipelineOutput (line 45) | class StableDiffusionSafePipelineOutput(BaseOutput):

FILE: diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
  class StableDiffusionPipelineSafe (line 22) | class StableDiffusionPipelineSafe(DiffusionPipeline):
    method __init__ (line 54) | def __init__(
    method safety_concept (line 150) | def safety_concept(self):
    method safety_concept (line 160) | def safety_concept(self, concept):
    method enable_sequential_cpu_offload (line 170) | def enable_sequential_cpu_offload(self):
    method _execution_device (line 189) | def _execution_device(self):
    method _encode_prompt (line 206) | def _encode_prompt(
    method run_safety_checker (line 341) | def run_safety_checker(self, image, device, dtype, enable_safety_guida...
    method decode_latents (line 365) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 379) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 397) | def check_inputs(
    method prepare_latents (line 445) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method perform_safety_guidance (line 462) | def perform_safety_guidance(
    method __call__ (line 505) | def __call__(

FILE: diffusers/pipelines/stable_diffusion_safe/safety_checker.py
  function cosine_distance (line 25) | def cosine_distance(image_embeds, text_embeds):
  class SafeStableDiffusionSafetyChecker (line 31) | class SafeStableDiffusionSafetyChecker(PreTrainedModel):
    method __init__ (line 36) | def __init__(self, config: CLIPConfig):
    method forward (line 49) | def forward(self, clip_input, images):
    method forward_onnx (line 88) | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.Fl...

FILE: diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
  class KarrasVePipeline (line 25) | class KarrasVePipeline(DiffusionPipeline):
    method __init__ (line 44) | def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
    method __call__ (line 49) | def __call__(

FILE: diffusers/pipelines/text_to_video_synthesis/__init__.py
  class TextToVideoSDPipelineOutput (line 11) | class TextToVideoSDPipelineOutput(BaseOutput):

FILE: diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
  function tensor2vid (line 58) | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5,...
  class TextToVideoSDPipeline (line 76) | class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMix...
    method __init__ (line 97) | def __init__(
    method enable_vae_slicing (line 117) | def enable_vae_slicing(self):
    method disable_vae_slicing (line 127) | def disable_vae_slicing(self):
    method enable_vae_tiling (line 135) | def enable_vae_tiling(self):
    method disable_vae_tiling (line 145) | def disable_vae_tiling(self):
    method enable_sequential_cpu_offload (line 152) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method enable_model_cpu_offload (line 173) | def enable_model_cpu_offload(self, gpu_id=0):
    method _execution_device (line 200) | def _execution_device(self):
    method _encode_prompt (line 218) | def _encode_prompt(
    method decode_latents (line 364) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 388) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 406) | def check_inputs(
    method prepare_latents (line 453) | def prepare_latents(
    method __call__ (line 480) | def __call__(

FILE: diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
  function rearrange_0 (line 18) | def rearrange_0(tensor, f):
  function rearrange_1 (line 24) | def rearrange_1(tensor):
  function rearrange_3 (line 29) | def rearrange_3(tensor, f):
  function rearrange_4 (line 34) | def rearrange_4(tensor):
  class CrossFrameAttnProcessor (line 39) | class CrossFrameAttnProcessor:
    method __init__ (line 49) | def __init__(self, batch_size=2):
    method __call__ (line 52) | def __call__(self, attn, hidden_states, encoder_hidden_states=None, at...
  class TextToVideoPipelineOutput (line 99) | class TextToVideoPipelineOutput(BaseOutput):
  function coords_grid (line 104) | def coords_grid(batch, ht, wd, device):
  function warp_single_latent (line 111) | def warp_single_latent(latent, reference_flow):
  function create_motion_field (line 138) | def create_motion_field(motion_field_strength_x, motion_field_strength_y...
  function create_motion_field_and_warp_latents (line 161) | def create_motion_field_and_warp_latents(motion_field_strength_x, motion...
  class TextToVideoZeroPipeline (line 188) | class TextToVideoZeroPipeline(StableDiffusionPipeline):
    method __init__ (line 216) | def __init__(
    method forward_loop (line 232) | def forward_loop(self, x_t0, t0, t1, generator):
    method backward_loop (line 250) | def backward_loop(
    method __call__ (line 320) | def __call__(

FILE: diffusers/pipelines/unclip/pipeline_unclip.py
  class UnCLIPPipeline (line 34) | class UnCLIPPipeline(DiffusionPipeline):
    method __init__ (line 78) | def __init__(
    method prepare_latents (line 106) | def prepare_latents(self, shape, dtype, device, generator, latents, sc...
    method _encode_prompt (line 117) | def _encode_prompt(
    method enable_sequential_cpu_offload (line 208) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 234) | def _execution_device(self):
    method __call__ (line 252) | def __call__(

FILE: diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
  class UnCLIPImageVariationPipeline (line 38) | class UnCLIPImageVariationPipeline(DiffusionPipeline):
    method __init__ (line 84) | def __init__(
    method prepare_latents (line 113) | def prepare_latents(self, shape, dtype, device, generator, latents, sc...
    method _encode_prompt (line 124) | def _encode_prompt(self, prompt, device, num_images_per_prompt, do_cla...
    method _encode_image (line 187) | def _encode_image(self, image, device, num_images_per_prompt, image_em...
    method enable_sequential_cpu_offload (line 201) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 227) | def _execution_device(self):
    method __call__ (line 245) | def __call__(

FILE: diffusers/pipelines/unclip/text_proj.py
  class UnCLIPTextProjModel (line 22) | class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
    method __init__ (line 31) | def __init__(
    method forward (line 55) | def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hid...

FILE: diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
  function get_down_block (line 27) | def get_down_block(
  function get_up_block (line 86) | def get_up_block(
  class UNetFlatConditionModel (line 146) | class UNetFlatConditionModel(ModelMixin, ConfigMixin):
    method __init__ (line 225) | def __init__(
    method attn_processors (line 585) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 608) | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict...
    method set_default_attn_processor (line 638) | def set_default_attn_processor(self):
    method set_attention_slice (line 644) | def set_attention_slice(self, slice_size):
    method _set_gradient_checkpointing (line 709) | def _set_gradient_checkpointing(self, module, value=False):
    method forward (line 713) | def forward(
  class LinearMultiDim (line 902) | class LinearMultiDim(nn.Linear):
    method __init__ (line 903) | def __init__(self, in_features, out_features=None, second_dim=4, *args...
    method forward (line 912) | def forward(self, input_tensor, *args, **kwargs):
  class ResnetBlockFlat (line 921) | class ResnetBlockFlat(nn.Module):
    method __init__ (line 922) | def __init__(
    method forward (line 982) | def forward(self, input_tensor, temb):
  class DownBlockFlat (line 1016) | class DownBlockFlat(nn.Module):
    method __init__ (line 1017) | def __init__(
    method forward (line 1068) | def forward(self, hidden_states, temb=None):
  class CrossAttnDownBlockFlat (line 1103) | class CrossAttnDownBlockFlat(nn.Module):
    method __init__ (line 1104) | def __init__(
    method forward (line 1190) | def forward(
  class UpBlockFlat (line 1250) | class UpBlockFlat(nn.Module):
    method __init__ (line 1251) | def __init__(
    method forward (line 1298) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, u...
  class CrossAttnUpBlockFlat (line 1332) | class CrossAttnUpBlockFlat(nn.Module):
    method __init__ (line 1333) | def __init__(
    method forward (line 1415) | def forward(
  class UNetMidBlockFlatCrossAttn (line 1481) | class UNetMidBlockFlatCrossAttn(nn.Module):
    method __init__ (line 1482) | def __init__(
    method forward (line 1566) | def forward(
  class UNetMidBlockFlatSimpleCrossAttn (line 1583) | class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
    method __init__ (line 1584) | def __init__(
    method forward (line 1668) | def forward(

FILE: diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
  class VersatileDiffusionPipeline (line 20) | class VersatileDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 57) | def __init__(
    method image_variation (line 83) | def image_variation(
    method text_to_image (line 199) | def text_to_image(
    method dual_guided (line 311) | def dual_guided(

FILE: diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
  class VersatileDiffusionDualGuidedPipeline (line 41) | class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
    method __init__ (line 70) | def __init__(
    method remove_unused_weights (line 101) | def remove_unused_weights(self):
    method _convert_to_dual_attention (line 104) | def _convert_to_dual_attention(self):
    method _revert_dual_attention (line 138) | def _revert_dual_attention(self):
    method enable_sequential_cpu_offload (line 151) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 170) | def _execution_device(self):
    method _encode_text_prompt (line 187) | def _encode_text_prompt(self, prompt, device, num_images_per_prompt, d...
    method _encode_image_prompt (line 278) | def _encode_image_prompt(self, prompt, device, num_images_per_prompt, ...
    method decode_latents (line 334) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 348) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 365) | def check_inputs(self, prompt, image, height, width, callback_steps):
    method prepare_latents (line 383) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method set_transformer_params (line 400) | def set_transformer_params(self, mix_ratio: float = 0.5, condition_typ...
    method __call__ (line 414) | def __call__(

FILE: diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
  class VersatileDiffusionImageVariationPipeline (line 35) | class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
    method __init__ (line 59) | def __init__(
    method enable_sequential_cpu_offload (line 78) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 97) | def _execution_device(self):
    method _encode_prompt (line 114) | def _encode_prompt(self, prompt, device, num_images_per_prompt, do_cla...
    method decode_latents (line 194) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 208) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 226) | def check_inputs(self, image, height, width, callback_steps):
    method prepare_latents (line 249) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 267) | def __call__(

FILE: diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
  class VersatileDiffusionTextToImagePipeline (line 34) | class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
    method __init__ (line 62) | def __init__(
    method _swap_unet_attention_blocks (line 86) | def _swap_unet_attention_blocks(self):
    method remove_unused_weights (line 99) | def remove_unused_weights(self):
    method enable_sequential_cpu_offload (line 102) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 121) | def _execution_device(self):
    method _encode_prompt (line 138) | def _encode_prompt(self, prompt, device, num_images_per_prompt, do_cla...
    method decode_latents (line 251) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 265) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 283) | def check_inputs(
    method prepare_latents (line 331) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
    method __call__ (line 349) | def __call__(

FILE: diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
  class LearnedClassifierFreeSamplingEmbeddings (line 30) | class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
    method __init__ (line 36) | def __init__(self, learnable: bool, hidden_size: Optional[int] = None,...
  class VQDiffusionPipeline (line 52) | class VQDiffusionPipeline(DiffusionPipeline):
    method __init__ (line 83) | def __init__(
    method _encode_prompt (line 103) | def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_...
    method __call__ (line 167) | def __call__(
    method truncate (line 310) | def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: floa...

FILE: diffusers/schedulers/scheduling_ddim.py
  class DDIMSchedulerOutput (line 32) | class DDIMSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 50) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class DDIMScheduler (line 79) | class DDIMScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 131) | def __init__(
    method scale_model_input (line 180) | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optio...
    method _get_variance (line 194) | def _get_variance(self, timestep, prev_timestep):
    method _threshold_sample (line 205) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method set_timesteps (line 239) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method step (line 270) | def step(
    method add_noise (line 393) | def add_noise(
    method get_velocity (line 417) | def get_velocity(
    method __len__ (line 437) | def __len__(self):

FILE: diffusers/schedulers/scheduling_ddim_flax.py
  class DDIMSchedulerState (line 36) | class DDIMSchedulerState:
    method create (line 46) | def create(
  class FlaxDDIMSchedulerOutput (line 62) | class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
  class FlaxDDIMScheduler (line 66) | class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 109) | def has_state(self):
    method __init__ (line 113) | def __init__(
    method create_state (line 127) | def create_state(self, common: Optional[CommonSchedulerState] = None) ...
    method scale_model_input (line 151) | def scale_model_input(
    method set_timesteps (line 165) | def set_timesteps(
    method _get_variance (line 187) | def _get_variance(self, state: DDIMSchedulerState, timestep, prev_time...
    method step (line 199) | def step(
    method add_noise (line 286) | def add_noise(
    method get_velocity (line 295) | def get_velocity(
    method __len__ (line 304) | def __len__(self):

FILE: diffusers/schedulers/scheduling_ddim_inverse.py
  class DDIMSchedulerOutput (line 31) | class DDIMSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 49) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class DDIMInverseScheduler (line 78) | class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 119) | def __init__(
    method scale_model_input (line 172) | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optio...
    method set_timesteps (line 186) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method step (line 210) | def step(
    method __len__ (line 267) | def __len__(self):

FILE: diffusers/schedulers/scheduling_ddpm.py
  class DDPMSchedulerOutput (line 30) | class DDPMSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 47) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class DDPMScheduler (line 76) | class DDPMScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 123) | def __init__(
    method scale_model_input (line 171) | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optio...
    method set_timesteps (line 185) | def set_timesteps(
    method _get_variance (line 238) | def _get_variance(self, t, predicted_variance=None, variance_type=None):
    method _threshold_sample (line 278) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method step (line 312) | def step(
    method add_noise (line 408) | def add_noise(
    method get_velocity (line 431) | def get_velocity(
    method __len__ (line 451) | def __len__(self):
    method previous_timestep (line 454) | def previous_timestep(self, timestep):

FILE: diffusers/schedulers/scheduling_ddpm_flax.py
  class DDPMSchedulerState (line 36) | class DDPMSchedulerState:
    method create (line 45) | def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.nd...
  class FlaxDDPMSchedulerOutput (line 50) | class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
  class FlaxDDPMScheduler (line 54) | class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 92) | def has_state(self):
    method __init__ (line 96) | def __init__(
    method create_state (line 110) | def create_state(self, common: Optional[CommonSchedulerState] = None) ...
    method scale_model_input (line 125) | def scale_model_input(
    method set_timesteps (line 139) | def set_timesteps(
    method _get_variance (line 162) | def _get_variance(self, state: DDPMSchedulerState, t, predicted_varian...
    method step (line 195) | def step(
    method add_noise (line 280) | def add_noise(
    method get_velocity (line 289) | def get_velocity(
    method __len__ (line 298) | def __len__(self):

FILE: diffusers/schedulers/scheduling_deis_multistep.py
  function betas_for_alpha_bar (line 29) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class DEISMultistepScheduler (line 58) | class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 113) | def __init__(
    method set_timesteps (line 174) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method _threshold_sample (line 206) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method convert_model_output (line 240) | def convert_model_output(
    method deis_first_order_update (line 278) | def deis_first_order_update(
    method multistep_deis_second_order_update (line 308) | def multistep_deis_second_order_update(
    method multistep_deis_third_order_update (line 350) | def multistep_deis_third_order_update(
    method step (line 407) | def step(
    method scale_model_input (line 475) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method add_noise (line 489) | def add_noise(
    method __len__ (line 512) | def __len__(self):

FILE: diffusers/schedulers/scheduling_dpmsolver_multistep.py
  function betas_for_alpha_bar (line 29) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class DPMSolverMultistepScheduler (line 58) | class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 143) | def __init__(
    method set_timesteps (line 208) | def set_timesteps(self, num_inference_steps: int = None, device: Union...
    method _threshold_sample (line 250) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method _sigma_to_t (line 285) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 309) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inferen...
    method convert_model_output (line 322) | def convert_model_output(
    method dpm_solver_first_order_update (line 397) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 444) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 536) | def multistep_dpm_solver_third_order_update(
    method step (line 591) | def step(
    method scale_model_input (line 669) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method add_noise (line 683) | def add_noise(
    method __len__ (line 706) | def __len__(self):

FILE: diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
  class DPMSolverMultistepSchedulerState (line 35) | class DPMSolverMultistepSchedulerState:
    method create (line 53) | def create(
  class FlaxDPMSolverMultistepSchedulerOutput (line 73) | class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput):
  class FlaxDPMSolverMultistepScheduler (line 77) | class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 147) | def has_state(self):
    method __init__ (line 151) | def __init__(
    method create_state (line 170) | def create_state(self, common: Optional[CommonSchedulerState] = None) ...
    method set_timesteps (line 199) | def set_timesteps(
    method convert_model_output (line 236) | def convert_model_output(
    method dpm_solver_first_order_update (line 306) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 341) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 401) | def multistep_dpm_solver_third_order_update(
    method step (line 457) | def step(
    method scale_model_input (line 594) | def scale_model_input(
    method add_noise (line 612) | def add_noise(
    method __len__ (line 621) | def __len__(self):

FILE: diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
  function betas_for_alpha_bar (line 29) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class DPMSolverMultistepInverseScheduler (line 58) | class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 132) | def __init__(
    method set_timesteps (line 197) | def set_timesteps(self, num_inference_steps: int = None, device: Union...
    method _threshold_sample (line 237) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method _sigma_to_t (line 272) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 296) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inferen...
    method convert_model_output (line 310) | def convert_model_output(
    method dpm_solver_first_order_update (line 386) | def dpm_solver_first_order_update(
    method multistep_dpm_solver_second_order_update (line 434) | def multistep_dpm_solver_second_order_update(
    method multistep_dpm_solver_third_order_update (line 527) | def multistep_dpm_solver_third_order_update(
    method step (line 582) | def step(
    method scale_model_input (line 663) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method add_noise (line 677) | def add_noise(
    method __len__ (line 700) | def __len__(self):

FILE: diffusers/schedulers/scheduling_dpmsolver_sde.py
  class BatchedBrownianTree (line 26) | class BatchedBrownianTree:
    method __init__ (line 29) | def __init__(self, x, t0, t1, seed=None, **kwargs):
    method sort (line 44) | def sort(a, b):
    method __call__ (line 47) | def __call__(self, t0, t1):
  class BrownianTreeNoiseSampler (line 53) | class BrownianTreeNoiseSampler:
    method __init__ (line 68) | def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambd...
    method __call__ (line 73) | def __call__(self, sigma, sigma_next):
  function betas_for_alpha_bar (line 79) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class DPMSolverSDEScheduler (line 108) | class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 142) | def __init__(
    method index_for_timestep (line 178) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method scale_model_input (line 190) | def scale_model_input(
    method set_timesteps (line 210) | def set_timesteps(
    method _second_order_timesteps (line 263) | def _second_order_timesteps(self, sigmas, log_sigmas):
    method _sigma_to_t (line 279) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 303) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.Fl...
    method state_in_first_order (line 317) | def state_in_first_order(self):
    method step (line 320) | def step(
    method add_noise (line 421) | def add_noise(
    method __len__ (line 446) | def __len__(self):

FILE: diffusers/schedulers/scheduling_dpmsolver_singlestep.py
  function betas_for_alpha_bar (line 28) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class DPMSolverSinglestepScheduler (line 57) | class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 134) | def __init__(
    method get_order_list (line 197) | def get_order_list(self, num_inference_steps: int) -> List[int]:
    method set_timesteps (line 231) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method _threshold_sample (line 257) | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatT...
    method convert_model_output (line 291) | def convert_model_output(
    method dpm_solver_first_order_update (line 357) | def dpm_solver_first_order_update(
    method singlestep_dpm_solver_second_order_update (line 389) | def singlestep_dpm_solver_second_order_update(
    method singlestep_dpm_solver_third_order_update (line 450) | def singlestep_dpm_solver_third_order_update(
    method singlestep_dpm_solver_update (line 521) | def singlestep_dpm_solver_update(
    method step (line 558) | def step(
    method scale_model_input (line 614) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method add_noise (line 628) | def add_noise(
    method __len__ (line 651) | def __len__(self):

FILE: diffusers/schedulers/scheduling_euler_ancestral_discrete.py
  class EulerAncestralDiscreteSchedulerOutput (line 32) | class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 50) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class EulerAncestralDiscreteScheduler (line 79) | class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 109) | def __init__(
    method scale_model_input (line 149) | def scale_model_input(
    method set_timesteps (line 170) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method step (line 193) | def step(
    method add_noise (line 283) | def add_noise(
    method __len__ (line 308) | def __len__(self):

FILE: diffusers/schedulers/scheduling_euler_discrete.py
  class EulerDiscreteSchedulerOutput (line 32) | class EulerDiscreteSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 50) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class EulerDiscreteScheduler (line 79) | class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 116) | def __init__(
    method scale_model_input (line 159) | def scale_model_input(
    method set_timesteps (line 182) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method _sigma_to_t (line 220) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 244) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inferen...
    method step (line 257) | def step(
    method add_noise (line 356) | def add_noise(
    method __len__ (line 381) | def __len__(self):

FILE: diffusers/schedulers/scheduling_heun_discrete.py
  function betas_for_alpha_bar (line 26) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class HeunDiscreteScheduler (line 55) | class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 87) | def __init__(
    method index_for_timestep (line 119) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method scale_model_input (line 131) | def scale_model_input(
    method set_timesteps (line 150) | def set_timesteps(
    method _sigma_to_t (line 200) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 224) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inferen...
    method state_in_first_order (line 238) | def state_in_first_order(self):
    method step (line 241) | def step(
    method add_noise (line 325) | def add_noise(
    method __len__ (line 350) | def __len__(self):

FILE: diffusers/schedulers/scheduling_ipndm.py
  class IPNDMScheduler (line 25) | class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 44) | def __init__(
    method set_timesteps (line 61) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method step (line 85) | def step(
    method scale_model_input (line 135) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method _get_prev_sample (line 148) | def _get_prev_sample(self, sample, timestep_index, prev_timestep_index...
    method __len__ (line 160) | def __len__(self):

FILE: diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
  function betas_for_alpha_bar (line 27) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class KDPM2AncestralDiscreteScheduler (line 56) | class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 87) | def __init__(
    method index_for_timestep (line 118) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method scale_model_input (line 130) | def scale_model_input(
    method set_timesteps (line 153) | def set_timesteps(
    method sigma_to_t (line 216) | def sigma_to_t(self, sigma):
    method state_in_first_order (line 240) | def state_in_first_order(self):
    method step (line 243) | def step(
    method add_noise (line 332) | def add_noise(
    method __len__ (line 357) | def __len__(self):

FILE: diffusers/schedulers/scheduling_k_dpm_2_discrete.py
  function betas_for_alpha_bar (line 26) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torc...
  class KDPM2DiscreteScheduler (line 55) | class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 86) | def __init__(
    method index_for_timestep (line 117) | def index_for_timestep(self, timestep, schedule_timesteps=None):
    method scale_model_input (line 129) | def scale_model_input(
    method set_timesteps (line 152) | def set_timesteps(
    method sigma_to_t (line 205) | def sigma_to_t(self, sigma):
    method state_in_first_order (line 229) | def state_in_first_order(self):
    method step (line 232) | def step(
    method add_noise (line 313) | def add_noise(
    method __len__ (line 338) | def __len__(self):

FILE: diffusers/schedulers/scheduling_karras_ve.py
  class KarrasVeOutput (line 28) | class KarrasVeOutput(BaseOutput):
  class KarrasVeScheduler (line 48) | class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 83) | def __init__(
    method scale_model_input (line 100) | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optio...
    method set_timesteps (line 114) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method add_noise_to_input (line 135) | def add_noise_to_input(
    method step (line 156) | def step(
    method step_correct (line 194) | def step_correct(
    method add_noise (line 231) | def add_noise(self, original_samples, noise, timesteps):

FILE: diffusers/schedulers/scheduling_karras_ve_flax.py
  class KarrasVeSchedulerState (line 29) | class KarrasVeSchedulerState:
    method create (line 36) | def create(cls):
  class FlaxKarrasVeOutput (line 41) | class FlaxKarrasVeOutput(BaseOutput):
  class FlaxKarrasVeScheduler (line 59) | class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 91) | def has_state(self):
    method __init__ (line 95) | def __init__(
    method create_state (line 106) | def create_state(self):
    method set_timesteps (line 109) | def set_timesteps(
    method add_noise_to_input (line 137) | def add_noise_to_input(
    method step (line 163) | def step(
    method step_correct (line 199) | def step_correct(
    method add_noise (line 236) | def add_noise(self, state: KarrasVeSchedulerState, original_samples, n...

FILE: diffusers/schedulers/scheduling_lms_discrete.py
  class LMSDiscreteSchedulerOutput (line 30) | class LMSDiscreteSchedulerOutput(BaseOutput):
  function betas_for_alpha_bar (line 48) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class LMSDiscreteScheduler (line 77) | class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 111) | def __init__(
    method scale_model_input (line 153) | def scale_model_input(
    method get_lms_coefficient (line 174) | def get_lms_coefficient(self, order, t, current_order):
    method set_timesteps (line 196) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method _sigma_to_t (line 230) | def _sigma_to_t(self, sigma, log_sigmas):
    method _convert_to_karras (line 254) | def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.Fl...
    method step (line 267) | def step(
    method add_noise (line 338) | def add_noise(
    method __len__ (line 363) | def __len__(self):

FILE: diffusers/schedulers/scheduling_lms_discrete_flax.py
  class LMSDiscreteSchedulerState (line 33) | class LMSDiscreteSchedulerState:
    method create (line 46) | def create(
  class FlaxLMSSchedulerOutput (line 53) | class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
  class FlaxLMSDiscreteScheduler (line 57) | class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 90) | def has_state(self):
    method __init__ (line 94) | def __init__(
    method create_state (line 106) | def create_state(self, common: Optional[CommonSchedulerState] = None) ...
    method scale_model_input (line 123) | def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: ...
    method get_lms_coefficient (line 145) | def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order,...
    method set_timesteps (line 167) | def set_timesteps(
    method step (line 203) | def step(
    method add_noise (line 268) | def add_noise(
    method __len__ (line 282) | def __len__(self):

FILE: diffusers/schedulers/scheduling_pndm.py
  function betas_for_alpha_bar (line 28) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
  class PNDMScheduler (line 57) | class PNDMScheduler(SchedulerMixin, ConfigMixin):
    method __init__ (line 99) | def __init__(
    method set_timesteps (line 152) | def set_timesteps(self, num_inference_steps: int, device: Union[str, t...
    method step (line 192) | def step(
    method step_prk (line 223) | def step_prk(
    method step_plms (line 278) | def step_plms(
    method scale_model_input (line 345) | def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs...
    method _get_prev_sample (line 358) | def _get_prev_sample(self, sample, timestep, prev_timestep, model_outp...
    method add_noise (line 402) | def add_noise(
    method __len__ (line 425) | def __len__(self):

FILE: diffusers/schedulers/scheduling_pndm_flax.py
  class PNDMSchedulerState (line 35) | class PNDMSchedulerState:
    method create (line 53) | def create(
  class FlaxPNDMSchedulerOutput (line 69) | class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput):
  class FlaxPNDMScheduler (line 73) | class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
    method has_state (line 119) | def has_state(self):
    method __init__ (line 123) | def __init__(
    method create_state (line 143) | def create_state(self, common: Optional[CommonSchedulerState] = None) ...
    method set_timesteps (line 167) | def set_timesteps(self, state: PNDMSchedulerState, num
Condensed preview — 439 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (9,382K chars).
[
  {
    "path": ".gitignore",
    "chars": 2005,
    "preview": "# Initially taken from Github's Python gitignore file\n\nrun2\nrun\npretrained\n\n# Byte-compiled / optimized / DLL files\n__py"
  },
  {
    "path": "LICENSE",
    "chars": 11346,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 8251,
    "preview": "# Diff-Pruning: Structural Pruning for Diffusion Models\n\n<div align=\"center\">\n<img src=\"assets/framework.png\" width=\"80%"
  },
  {
    "path": "ddpm_exp/.gitignore",
    "chars": 40,
    "preview": ".vscode\n__pycache__\n*.log\nrun\ndata\n*.png"
  },
  {
    "path": "ddpm_exp/LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) 2020 Jiaming Song\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "ddpm_exp/README.md",
    "chars": 4781,
    "preview": "# Denoising Diffusion Implicit Models (DDIM)\n\n[Jiaming Song](http://tsong.me), [Chenlin Meng](http://cs.stanford.edu/~ch"
  },
  {
    "path": "ddpm_exp/calc_fid.py",
    "chars": 508,
    "preview": "from cleanfid import fid\nimport argparse\nparser = argparse.ArgumentParser(description=globals()[\"__doc__\"])\nparser.add_a"
  },
  {
    "path": "ddpm_exp/compute_flops.py",
    "chars": 637,
    "preview": "import torch\nimport random, os\nimport argparse\nfrom PIL import Image\nimport torchvision\nimport numpy as np\nimport pytorc"
  },
  {
    "path": "ddpm_exp/compute_pruned_ssim_curve.py",
    "chars": 938,
    "preview": "import pytorch_msssim \nimport os\nimport torch\nfrom PIL import Image\nimport torchvision\n\nbase_folder_name = 'run/prune_ss"
  },
  {
    "path": "ddpm_exp/compute_ssim.py",
    "chars": 1601,
    "preview": "import torch\nimport random, os\nimport argparse\nfrom PIL import Image\nimport torchvision\nimport numpy as np\nimport pytorc"
  },
  {
    "path": "ddpm_exp/compute_ssim_vis.py",
    "chars": 840,
    "preview": "import torch\nimport random, os\nimport argparse\nfrom PIL import Image\nimport torchvision\nimport numpy as np\nimport pytorc"
  },
  {
    "path": "ddpm_exp/configs/bedroom.yml",
    "chars": 910,
    "preview": "data:\n    dataset: \"LSUN\"\n    category: \"bedroom\"\n    image_size: 256\n    channels: 3\n    logit_transform: false\n    uni"
  },
  {
    "path": "ddpm_exp/configs/celeba.yml",
    "chars": 909,
    "preview": "data:\n    dataset: \"CELEBA\"\n    image_size: 64\n    channels: 3\n    logit_transform: false\n    uniform_dequantization: fa"
  },
  {
    "path": "ddpm_exp/configs/church.yml",
    "chars": 921,
    "preview": "data:\n    dataset: \"LSUN\"\n    category: \"church_outdoor\"\n    image_size: 256\n    channels: 3\n    logit_transform: false\n"
  },
  {
    "path": "ddpm_exp/configs/cifar10.yml",
    "chars": 901,
    "preview": "data:\n    dataset: \"CIFAR10\"\n    image_size: 32\n    channels: 3\n    logit_transform: false\n    uniform_dequantization: f"
  },
  {
    "path": "ddpm_exp/configs/cifar10_pruning.yml",
    "chars": 902,
    "preview": "data:\n    dataset: \"CIFAR10\"\n    image_size: 32\n    channels: 3\n    logit_transform: false\n    uniform_dequantization: f"
  },
  {
    "path": "ddpm_exp/datasets/__init__.py",
    "chars": 6788,
    "preview": "import os\nimport torch\nimport numbers\nimport torchvision.transforms as transforms\nimport torchvision.transforms.function"
  },
  {
    "path": "ddpm_exp/datasets/celeba.py",
    "chars": 7746,
    "preview": "import torch\nimport os\nimport PIL\nfrom .vision import VisionDataset\nfrom .utils import download_file_from_google_drive, "
  },
  {
    "path": "ddpm_exp/datasets/ffhq.py",
    "chars": 1055,
    "preview": "from io import BytesIO\n\nimport lmdb\nfrom PIL import Image\nfrom torch.utils.data import Dataset\n\n\nclass FFHQ(Dataset):\n  "
  },
  {
    "path": "ddpm_exp/datasets/lsun.py",
    "chars": 5503,
    "preview": "from .vision import VisionDataset\nfrom PIL import Image\nimport os\nimport os.path\nimport io\nfrom collections.abc import I"
  },
  {
    "path": "ddpm_exp/datasets/utils.py",
    "chars": 5665,
    "preview": "import os\nimport os.path\nimport hashlib\nimport errno\nfrom torch.utils.model_zoo import tqdm\n\n\ndef gen_bar_updater():\n   "
  },
  {
    "path": "ddpm_exp/datasets/vision.py",
    "chars": 3266,
    "preview": "import os\nimport torch\nimport torch.utils.data as data\n\n\nclass VisionDataset(data.Dataset):\n    _repr_indent = 4\n\n    de"
  },
  {
    "path": "ddpm_exp/draw_ssim_pruned_curve.py",
    "chars": 43152,
    "preview": "import matplotlib.pyplot as plt\nimport numpy as np\n\nplt.style.use('seaborn-whitegrid')\n\nssim = [0.7881933450698853, 0.80"
  },
  {
    "path": "ddpm_exp/extract_cifar10.py",
    "chars": 621,
    "preview": "import os\nimport torchvision\nfrom torchvision.datasets import CIFAR10\nfrom tqdm import tqdm\n\n# Define the path to the fo"
  },
  {
    "path": "ddpm_exp/fid_score.py",
    "chars": 13736,
    "preview": "\"\"\"Calculates the Frechet Inception Distance (FID) to evalulate GANs\n\nThe FID metric calculates the distance between two"
  },
  {
    "path": "ddpm_exp/finetune.py",
    "chars": 9259,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/finetune_simple.py",
    "chars": 13645,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/functions/__init__.py",
    "chars": 727,
    "preview": "import torch.optim as optim\n\n\ndef get_optimizer(config, parameters):\n    if config.optim.optimizer == 'Adam':\n        re"
  },
  {
    "path": "ddpm_exp/functions/ckpt_util.py",
    "chars": 3291,
    "preview": "import os, hashlib\nimport requests\nfrom tqdm import tqdm\n\nURL_MAP = {\n    \"cifar10\": \"https://heibox.uni-heidelberg.de/f"
  },
  {
    "path": "ddpm_exp/functions/denoising.py",
    "chars": 2363,
    "preview": "import torch\n\n\ndef compute_alpha(beta, t):\n    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)\n    a = ("
  },
  {
    "path": "ddpm_exp/functions/losses.py",
    "chars": 1373,
    "preview": "import torch\n\n\ndef noise_estimation_loss(model,\n                          x0: torch.Tensor,\n                          t:"
  },
  {
    "path": "ddpm_exp/inception.py",
    "chars": 12759,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\n\ntry:\n    from torchvision.models."
  },
  {
    "path": "ddpm_exp/main.py",
    "chars": 7498,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/models/diffusion.py",
    "chars": 12846,
    "preview": "import math\nimport torch\nimport torch.nn as nn\n\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This "
  },
  {
    "path": "ddpm_exp/models/ema.py",
    "chars": 2785,
    "preview": "import torch.nn as nn\n\n\nclass EMAHelper(object):\n    def __init__(self, mu=0.999):\n        self.mu = mu\n        self.sha"
  },
  {
    "path": "ddpm_exp/prune.py",
    "chars": 10016,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/prune_kd.py",
    "chars": 13970,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/prune_ssim.py",
    "chars": 10386,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/prune_test.py",
    "chars": 10886,
    "preview": "import argparse\nimport traceback\nimport shutil\nimport logging\nimport yaml\nimport sys\nimport os\nimport torch\nimport numpy"
  },
  {
    "path": "ddpm_exp/runners/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ddpm_exp/runners/diffusion.py",
    "chars": 20498,
    "preview": "import os\nimport logging\nimport time\nimport glob\n\nimport numpy as np\nimport tqdm\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "ddpm_exp/runners/diffusion_simple.py",
    "chars": 16557,
    "preview": "import os\nimport logging\nimport time\nimport glob\n\nimport numpy as np\nimport tqdm\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "ddpm_exp/scripts/finetune_bedroom_ddpm.sh",
    "chars": 441,
    "preview": "python -m torch.distributed.launch --nproc_per_node=6 --master_port 22223 --use_env finetune.py \\\n--config bedroom.yml \\"
  },
  {
    "path": "ddpm_exp/scripts/finetune_celeba_ddpm.sh",
    "chars": 281,
    "preview": "python finetune.py \\\n--config celeba.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_final/celeba_T=$1_finet"
  },
  {
    "path": "ddpm_exp/scripts/finetune_celeba_ddpm_kd.sh",
    "chars": 292,
    "preview": "python finetune.py \\\n--config celeba.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_v2/celeba_ddpm_$1_0.3_f"
  },
  {
    "path": "ddpm_exp/scripts/finetune_church_ddpm.sh",
    "chars": 358,
    "preview": "python -m torch.distributed.launch --nproc_per_node=4 --master_port 22223 --use_env finetune.py \\\n--config church.yml \\\n"
  },
  {
    "path": "ddpm_exp/scripts/finetune_cifar_ddpm.sh",
    "chars": 291,
    "preview": "python finetune.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_v3/cifar10_ddpm_$1_fin"
  },
  {
    "path": "ddpm_exp/scripts/finetune_cifar_ddpm_kd.sh",
    "chars": 296,
    "preview": "python finetune.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_v2/cifar10_ddpm_$1_0.3"
  },
  {
    "path": "ddpm_exp/scripts/finetune_cifar_ddpm_random.sh",
    "chars": 387,
    "preview": "python -m torch.distributed.launch --nproc_per_node=2 --master_port 22223 --use_env finetune.py \\\n--config cifar10.yml \\"
  },
  {
    "path": "ddpm_exp/scripts/finetune_cifar_ddpm_taylor.sh",
    "chars": 325,
    "preview": "python finetune.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune/cifar10_pruned_taylor_"
  },
  {
    "path": "ddpm_exp/scripts/old/run_bedroom_sample_pratrained.sh",
    "chars": 219,
    "preview": "python prune.py \\\n--config bedroom.yml \\\n--exp run/ddim_bedroom_official \\\n--sample \\\n--use_pretrained \\\n--timesteps 50 "
  },
  {
    "path": "ddpm_exp/scripts/old/run_celeba_pruning_scratch.sh",
    "chars": 251,
    "preview": "python prune.py \\\n--config celeba.yml \\\n--exp run/ddim_celeba_pruning_reinit \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc "
  },
  {
    "path": "ddpm_exp/scripts/old/run_celeba_pruning_taylor.sh",
    "chars": 251,
    "preview": "python prune.py \\\n--config celeba.yml \\\n--exp run/ddim_celeba_pruning_taylor \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc "
  },
  {
    "path": "ddpm_exp/scripts/old/run_celeba_sample_pratrained.sh",
    "chars": 291,
    "preview": "python prune.py \\\n--config celeba.yml \\\n--exp run/ddim_celeba_official \\\n--sample \\\n--use_pretrained \\\n--timesteps 100 \\"
  },
  {
    "path": "ddpm_exp/scripts/old/run_church_pruning_taylor.sh",
    "chars": 226,
    "preview": "python prune.py \\\n--config church.yml \\\n--exp run/ddim_church_pruning_taylor \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc "
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_first_order_taylor.sh",
    "chars": 252,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_first_order_taylor \\\n--timesteps 100 \\\n--eta 0 \\"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_magnitude.sh",
    "chars": 234,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_magnitude \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n-"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_random.sh",
    "chars": 228,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_random \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--do"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_random_kd.sh",
    "chars": 231,
    "preview": "python prune_kd.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_random \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n-"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_scratch.sh",
    "chars": 228,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_reinit \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--do"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_second_order_taylor.sh",
    "chars": 254,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_second_order_taylor \\\n--timesteps 100 \\\n--eta 0 "
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_taylor.sh",
    "chars": 228,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_taylor \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--do"
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_pruning_taylor_kd.sh",
    "chars": 234,
    "preview": "python prune_kd.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_pruning_taylor_kd \\\n--timesteps 100 \\\n--eta 0 \\\n--ni "
  },
  {
    "path": "ddpm_exp/scripts/old/run_cifar_train.sh",
    "chars": 230,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/ddim_cifar10_train_v2 \\\n--use_pretrained \\\n--timesteps 100 \\\n--eta 0 "
  },
  {
    "path": "ddpm_exp/scripts/prune_bedroom_ddpm.sh",
    "chars": 349,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune.py \\\n--config \"bedroom.yml\" \\\n--timest"
  },
  {
    "path": "ddpm_exp/scripts/prune_bedroom_ddpm_test.sh",
    "chars": 356,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune_test.py \\\n--config \"bedroom.yml\" \\\n--t"
  },
  {
    "path": "ddpm_exp/scripts/prune_celeba_ddpm.sh",
    "chars": 362,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune.py \\\n--config \"celeba.yml\" \\\n--timeste"
  },
  {
    "path": "ddpm_exp/scripts/prune_celeba_ddpm_ssim.sh",
    "chars": 283,
    "preview": "python prune_ssim.py \\\n--config celeba.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc post_training \\\n--skip_type quad  "
  },
  {
    "path": "ddpm_exp/scripts/prune_church_ddpm.sh",
    "chars": 347,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune.py \\\n--config \"church.yml\" \\\n--timeste"
  },
  {
    "path": "ddpm_exp/scripts/prune_church_ddpm_test.sh",
    "chars": 354,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune_test.py \\\n--config \"church.yml\" \\\n--ti"
  },
  {
    "path": "ddpm_exp/scripts/prune_cifar_ddpm.sh",
    "chars": 258,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc post_training \\\n--skip_type quad  \\\n--"
  },
  {
    "path": "ddpm_exp/scripts/prune_cifar_ddpm_ssim.sh",
    "chars": 261,
    "preview": "python prune_ssim.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc post_training \\\n--skip_type quad "
  },
  {
    "path": "ddpm_exp/scripts/prune_cifar_ddpm_test.sh",
    "chars": 255,
    "preview": "python prune_test.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc post_training \\\n--skip_type quad "
  },
  {
    "path": "ddpm_exp/scripts/run_celeba.sh",
    "chars": 645,
    "preview": "#!/bin/bash\n\n# Execute the Python script with the provided arguments\npython prune.py \\\n--config \"celeba.yml\" \\\n--timeste"
  },
  {
    "path": "ddpm_exp/scripts/sample_bedroom_ddpm_pretrained.sh",
    "chars": 276,
    "preview": "python -m torch.distributed.launch --nproc_per_node=1 --master_port 22200 --use_env finetune.py \\\n--config bedroom.yml \\"
  },
  {
    "path": "ddpm_exp/scripts/sample_bedroom_ddpm_pruning.sh",
    "chars": 277,
    "preview": "python -m torch.distributed.launch --nproc_per_node=4 --master_port 22223 --use_env finetune.py \\\n--config bedroom.yml \\"
  },
  {
    "path": "ddpm_exp/scripts/sample_celeba_ddpm_pruning.sh",
    "chars": 199,
    "preview": "python finetune.py \\\n--config celeba.yml \\\n--exp $2 \\\n--sample \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc sample \\\n--ski"
  },
  {
    "path": "ddpm_exp/scripts/sample_celeba_pretrained.sh",
    "chars": 224,
    "preview": "python prune.py \\\n--config celeba.yml \\\n--exp run/sample/ddim_celeba_official \\\n--sample \\\n--use_pretrained \\\n--timestep"
  },
  {
    "path": "ddpm_exp/scripts/sample_church_ddpm_pruning.sh",
    "chars": 307,
    "preview": "python finetune.py \\\n--config church.yml \\\n--exp run/sample/church_ddpm_350k \\\n--sample \\\n--timesteps 100 \\\n--eta 0 \\\n--"
  },
  {
    "path": "ddpm_exp/scripts/sample_church_ddpm_pruning_old.sh",
    "chars": 276,
    "preview": "python -m torch.distributed.launch --nproc_per_node=4 --master_port 22221 --use_env finetune.py \\\n--config church.yml \\\n"
  },
  {
    "path": "ddpm_exp/scripts/sample_church_ddpm_test.sh",
    "chars": 307,
    "preview": "python finetune.py \\\n--config church.yml \\\n--exp run/sample/church_ddpm_350k \\\n--sample \\\n--timesteps 100 \\\n--eta 0 \\\n--"
  },
  {
    "path": "ddpm_exp/scripts/sample_church_pretrained.sh",
    "chars": 224,
    "preview": "python prune.py \\\n--config church.yml \\\n--exp run/sample/ddim_church_official \\\n--sample \\\n--use_pretrained \\\n--timestep"
  },
  {
    "path": "ddpm_exp/scripts/sample_cifar_ddpm_pruning.sh",
    "chars": 201,
    "preview": "python finetune.py \\\n--config cifar10.yml \\\n--exp \"$2\" \\\n--sample \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--doc sample \\\n--"
  },
  {
    "path": "ddpm_exp/scripts/sample_cifar_pretrained.sh",
    "chars": 226,
    "preview": "python prune.py \\\n--config cifar10.yml \\\n--exp run/sample/ddim_cifar10_official \\\n--sample \\\n--use_pretrained \\\n--timest"
  },
  {
    "path": "ddpm_exp/scripts/simple_celeba_our.sh",
    "chars": 281,
    "preview": "python finetune_simple.py \\\n--config celeba.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_simple/celeba_ou"
  },
  {
    "path": "ddpm_exp/scripts/simple_cifar_our.sh",
    "chars": 258,
    "preview": "python finetune_simple.py \\\n--config cifar10.yml \\\n--timesteps 100 \\\n--eta 0 \\\n--ni \\\n--exp run/finetune_simple_v2/cifar"
  },
  {
    "path": "ddpm_exp/tools/extract_cifar10.py",
    "chars": 621,
    "preview": "import os\nimport torchvision\nfrom torchvision.datasets import CIFAR10\nfrom tqdm import tqdm\n\n# Define the path to the fo"
  },
  {
    "path": "ddpm_exp/tools/transform_weights.py",
    "chars": 232,
    "preview": "import torch\n\nstate = torch.load(\"model.ckpt.old\")\nold_dict = state[0]\nprint(state[0].keys())\nstate[0] = {pname.replace("
  },
  {
    "path": "ddpm_exp/torch_pruning/__init__.py",
    "chars": 89,
    "preview": "from .dependency import *\nfrom .pruner import *\nfrom . import _helpers, utils, importance"
  },
  {
    "path": "ddpm_exp/torch_pruning/_helpers.py",
    "chars": 3356,
    "preview": "import torch.nn as nn\nimport numpy as np\nimport torch\nfrom operator import add\nfrom numbers import Number\n\n\ndef is_scala"
  },
  {
    "path": "ddpm_exp/torch_pruning/dependency.py",
    "chars": 42732,
    "preview": "import typing\nimport warnings\nfrom numbers import Number\nfrom collections import namedtuple\n\nimport torch\nimport torch.n"
  },
  {
    "path": "ddpm_exp/torch_pruning/importance.py",
    "chars": 33072,
    "preview": "import abc\nimport torch\nimport torch.nn as nn\n\nimport typing\nfrom .pruner import function\nfrom ._helpers import _Flatten"
  },
  {
    "path": "ddpm_exp/torch_pruning/ops.py",
    "chars": 6836,
    "preview": "import torch.nn as nn\nfrom enum import IntEnum\n\n\nclass DummyMHA(nn.Module):\n    def __init__(self):\n        super(DummyM"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/__init__.py",
    "chars": 49,
    "preview": "from .function import *\nfrom .algorithms import *"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/__init__.py",
    "chars": 279,
    "preview": "from .metapruner import MetaPruner\nfrom .magnitude_based_pruner import MagnitudePruner\nfrom .batchnorm_scale_pruner impo"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py",
    "chars": 1614,
    "preview": "from numbers import Number\nfrom typing import Callable\nfrom .metapruner import MetaPruner\nfrom .scheduler import linear_"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/group_norm_pruner.py",
    "chars": 8060,
    "preview": "import torch\nimport math\nfrom .metapruner import MetaPruner\nfrom .scheduler import linear_scheduler\nfrom .. import funct"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/magnitude_based_pruner.py",
    "chars": 84,
    "preview": "from .metapruner import MetaPruner\n\nclass MagnitudePruner(MetaPruner):\n    pass\n    "
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/metapruner.py",
    "chars": 12611,
    "preview": "import torch\nimport torch.nn as nn\nimport typing\n\nfrom .scheduler import linear_scheduler\nfrom ..import function\nfrom .."
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/scaling_factor_pruner.py",
    "chars": 3358,
    "preview": "from numbers import Number\nfrom typing import Callable\nfrom .metapruner import MetaPruner\nfrom .scheduler import linear_"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/scheduler.py",
    "chars": 124,
    "preview": "\ndef linear_scheduler(ch_sparsity_dict, steps):\n    return [((i) / float(steps)) * ch_sparsity_dict for i in range(steps"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/algorithms/taylor_pruner.py",
    "chars": 5926,
    "preview": "import torch\nimport math\nfrom .metapruner import MetaPruner\nfrom .scheduler import linear_scheduler\nfrom .. import funct"
  },
  {
    "path": "ddpm_exp/torch_pruning/pruner/function.py",
    "chars": 22076,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .. import ops\n\nfrom copy import deepcopy\nfrom functools import reduce\nfrom oper"
  },
  {
    "path": "ddpm_exp/torch_pruning/utils/__init__.py",
    "chars": 65,
    "preview": "from .utils import *\nfrom .op_counter import count_ops_and_params"
  },
  {
    "path": "ddpm_exp/torch_pruning/utils/op_counter.py",
    "chars": 15847,
    "preview": "'''\nThis opcounter is adapted from https://github.com/sovrasov/flops-counter.pytorch\n\nCopyright (C) 2021 Sovrasov V. - A"
  },
  {
    "path": "ddpm_exp/torch_pruning/utils/utils.py",
    "chars": 5557,
    "preview": "from ..ops import TORCH_CONV, TORCH_BATCHNORM, TORCH_PRELU, TORCH_LINEAR\nfrom ..ops import module2type\nimport torch\nfrom"
  },
  {
    "path": "ddpm_exp/utils.py",
    "chars": 816,
    "preview": "import torch, os\nfrom glob import glob\nfrom PIL import Image\n\nclass UnlabeledImageFolder(torch.utils.data.Dataset):\n    "
  },
  {
    "path": "ddpm_prune.py",
    "chars": 6405,
    "preview": "from diffusers import DiffusionPipeline, DDPMPipeline, DDIMPipeline, DDIMScheduler, DDPMScheduler\nfrom diffusers.models "
  },
  {
    "path": "ddpm_sample.py",
    "chars": 3843,
    "preview": "from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel\nimport argparse, os, torch\nfrom tqdm import tqdm\nimport t"
  },
  {
    "path": "ddpm_train.py",
    "chars": 23097,
    "preview": "# Modifed from https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation\n\nimport argpar"
  },
  {
    "path": "diffusers/__init__.py",
    "chars": 8090,
    "preview": "__version__ = \"0.17.0.dev0\"\n\nfrom .configuration_utils import ConfigMixin\nfrom .utils import (\n    OptionalDependencyNot"
  },
  {
    "path": "diffusers/commands/__init__.py",
    "chars": 920,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/commands/diffusers_cli.py",
    "chars": 1200,
    "preview": "#!/usr/bin/env python\n# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License,"
  },
  {
    "path": "diffusers/commands/env.py",
    "chars": 2870,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/configuration_utils.py",
    "chars": 29814,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserve"
  },
  {
    "path": "diffusers/dependency_versions_check.py",
    "chars": 1756,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/dependency_versions_table.py",
    "chars": 1320,
    "preview": "# THIS FILE HAS BEEN AUTOGENERATED. To update:\n# 1. modify the `_deps` dict in setup.py\n# 2. run `make deps_table_update"
  },
  {
    "path": "diffusers/experimental/README.md",
    "chars": 284,
    "preview": "# 🧨 Diffusers Experimental\n\nWe are adding experimental code to support novel applications and usages of the Diffusers li"
  },
  {
    "path": "diffusers/experimental/__init__.py",
    "chars": 38,
    "preview": "from .rl import ValueGuidedRLPipeline\n"
  },
  {
    "path": "diffusers/experimental/rl/__init__.py",
    "chars": 57,
    "preview": "from .value_guided_sampling import ValueGuidedRLPipeline\n"
  },
  {
    "path": "diffusers/experimental/rl/value_guided_sampling.py",
    "chars": 6133,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/image_processor.py",
    "chars": 8080,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/loaders.py",
    "chars": 69051,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/README.md",
    "chars": 118,
    "preview": "# Models\n\nFor more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models)."
  },
  {
    "path": "diffusers/models/__init__.py",
    "chars": 1446,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/attention.py",
    "chars": 14182,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/attention_flax.py",
    "chars": 17680,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/attention_processor.py",
    "chars": 52895,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/autoencoder_kl.py",
    "chars": 14692,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/controlnet.py",
    "chars": 25899,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/controlnet_flax.py",
    "chars": 16215,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/cross_attention.py",
    "chars": 4875,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/dual_transformer_2d.py",
    "chars": 7297,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/embeddings.py",
    "chars": 15784,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/embeddings_flax.py",
    "chars": 3443,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/modeling_flax_pytorch_utils.py",
    "chars": 4601,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "diffusers/models/modeling_flax_utils.py",
    "chars": 25840,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "diffusers/models/modeling_pytorch_flax_utils.py",
    "chars": 6974,
    "preview": "# coding=utf-8\r\n# Copyright 2023 The HuggingFace Inc. team.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"L"
  },
  {
    "path": "diffusers/models/modeling_utils.py",
    "chars": 39833,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserve"
  },
  {
    "path": "diffusers/models/prior_transformer.py",
    "chars": 8648,
    "preview": "from dataclasses import dataclass\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom "
  },
  {
    "path": "diffusers/models/resnet.py",
    "chars": 34822,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The"
  },
  {
    "path": "diffusers/models/resnet_flax.py",
    "chars": 4021,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/t5_film_transformer.py",
    "chars": 11824,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/transformer_2d.py",
    "chars": 15801,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/transformer_temporal.py",
    "chars": 7448,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_1d.py",
    "chars": 10653,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_1d_blocks.py",
    "chars": 24865,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_2d.py",
    "chars": 14666,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_2d_blocks.py",
    "chars": 110736,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_2d_blocks_flax.py",
    "chars": 14242,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_2d_condition.py",
    "chars": 38781,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_2d_condition_flax.py",
    "chars": 15136,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_3d_blocks.py",
    "chars": 23917,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/unet_3d_condition.py",
    "chars": 25055,
    "preview": "# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.\n# Copyright 2023 The ModelScope Team."
  },
  {
    "path": "diffusers/models/vae.py",
    "chars": 15235,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/vae_flax.py",
    "chars": 31817,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/models/vq_model.py",
    "chars": 6469,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/optimization.py",
    "chars": 14541,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
  },
  {
    "path": "diffusers/pipeline_utils.py",
    "chars": 1147,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/README.md",
    "chars": 17601,
    "preview": "# 🧨 Diffusers Pipelines\n\nPipelines provide a simple way to run state-of-the-art diffusion models in inference.\nMost diff"
  },
  {
    "path": "diffusers/pipelines/__init__.py",
    "chars": 5760,
    "preview": "from ..utils import (\n    OptionalDependencyNotAvailable,\n    is_flax_available,\n    is_k_diffusion_available,\n    is_li"
  },
  {
    "path": "diffusers/pipelines/alt_diffusion/__init__.py",
    "chars": 1346,
    "preview": "from dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport PIL\nfrom PIL impor"
  },
  {
    "path": "diffusers/pipelines/alt_diffusion/modeling_roberta_series.py",
    "chars": 5580,
    "preview": "from dataclasses import dataclass\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import nn\nfrom transformer"
  },
  {
    "path": "diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py",
    "chars": 38177,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py",
    "chars": 40366,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/audio_diffusion/__init__.py",
    "chars": 82,
    "preview": "from .mel import Mel\nfrom .pipeline_audio_diffusion import AudioDiffusionPipeline\n"
  },
  {
    "path": "diffusers/pipelines/audio_diffusion/mel.py",
    "chars": 5381,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py",
    "chars": 10699,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/audioldm/__init__.py",
    "chars": 503,
    "preview": "from ...utils import (\n    OptionalDependencyNotAvailable,\n    is_torch_available,\n    is_transformers_available,\n    is"
  },
  {
    "path": "diffusers/pipelines/audioldm/pipeline_audioldm.py",
    "chars": 28856,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/controlnet/__init__.py",
    "chars": 817,
    "preview": "from ...utils import (\n    OptionalDependencyNotAvailable,\n    is_flax_available,\n    is_torch_available,\n    is_transfo"
  },
  {
    "path": "diffusers/pipelines/controlnet/multicontrolnet.py",
    "chars": 2424,
    "preview": "from typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom ...models.controlnet"
  },
  {
    "path": "diffusers/pipelines/controlnet/pipeline_controlnet.py",
    "chars": 50207,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py",
    "chars": 54122,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py",
    "chars": 61489,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/controlnet/pipeline_flax_controlnet.py",
    "chars": 23354,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/dance_diffusion/__init__.py",
    "chars": 61,
    "preview": "from .pipeline_dance_diffusion import DanceDiffusionPipeline\n"
  },
  {
    "path": "diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py",
    "chars": 5637,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/ddim/__init__.py",
    "chars": 40,
    "preview": "from .pipeline_ddim import DDIMPipeline\n"
  },
  {
    "path": "diffusers/pipelines/ddim/pipeline_ddim.py",
    "chars": 5582,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/ddpm/__init__.py",
    "chars": 40,
    "preview": "from .pipeline_ddpm import DDPMPipeline\n"
  },
  {
    "path": "diffusers/pipelines/ddpm/pipeline_ddpm.py",
    "chars": 4495,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/__init__.py",
    "chars": 2163,
    "preview": "from dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport PIL\n\nfrom ...utils"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if.py",
    "chars": 39510,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py",
    "chars": 44112,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py",
    "chars": 49036,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py",
    "chars": 49324,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py",
    "chars": 54059,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py",
    "chars": 43849,
    "preview": "import html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Any, Callable, Dict, List, Optional, U"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/safety_checker.py",
    "chars": 2117,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom transformers import CLIPConfig, CLIPVisionModelWithProjection"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/timesteps.py",
    "chars": 5164,
    "preview": "fast27_timesteps = [\n    999,\n    800,\n    799,\n    600,\n    599,\n    500,\n    400,\n    399,\n    377,\n    355,\n    333,\n"
  },
  {
    "path": "diffusers/pipelines/deepfloyd_if/watermark.py",
    "chars": 1595,
    "preview": "from typing import List\n\nimport PIL\nimport torch\nfrom PIL import Image\n\nfrom ...configuration_utils import ConfigMixin\nf"
  },
  {
    "path": "diffusers/pipelines/dit/__init__.py",
    "chars": 38,
    "preview": "from .pipeline_dit import DiTPipeline\n"
  },
  {
    "path": "diffusers/pipelines/dit/pipeline_dit.py",
    "chars": 8478,
    "preview": "# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)\n# William Peebles and Saining Xie\n#\n# Copyright (c) 2021 Op"
  },
  {
    "path": "diffusers/pipelines/latent_diffusion/__init__.py",
    "chars": 243,
    "preview": "from ...utils import is_transformers_available\nfrom .pipeline_latent_diffusion_superresolution import LDMSuperResolution"
  },
  {
    "path": "diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py",
    "chars": 32468,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py",
    "chars": 6959,
    "preview": "import inspect\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch."
  },
  {
    "path": "diffusers/pipelines/latent_diffusion_uncond/__init__.py",
    "chars": 58,
    "preview": "from .pipeline_latent_diffusion_uncond import LDMPipeline\n"
  },
  {
    "path": "diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py",
    "chars": 4888,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  },
  {
    "path": "diffusers/pipelines/onnx_utils.py",
    "chars": 8282,
    "preview": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserve"
  }
]

// ... and 239 more files (download for full content)

About this extraction

This page contains the full source code of the VainF/Diff-Pruning GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 439 files (8.8 MB), approximately 2.3M tokens, and a symbol index of 3,841 functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — a free GitHub-repo-to-text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!