Repository: cure-lab/PnPInversion
Branch: main
Commit: 07f97f448150
Files: 339
Total size: 9.7 MB
Directory structure:
gitextract_s8_0ppc8/
├── .gitignore
├── README.md
├── environment/
│ ├── edict_requirements.txt
│ ├── instructdiffusion_requirements.txt
│ ├── masactrl_requirements.txt
│ ├── p2p_requirements.txt
│ ├── pix2pix_zero_requirements.txt
│ └── pnp_requirements.txt
├── evaluation/
│ ├── evaluate.py
│ └── matrics_calculator.py
├── models/
│ ├── InstructDiffusion/
│ │ ├── .gitignore
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── configs/
│ │ │ └── instruct_diffusion.yaml
│ │ ├── dataset/
│ │ │ ├── README.md
│ │ │ ├── editing/
│ │ │ │ └── edit_zip_dataset.py
│ │ │ ├── low_level/
│ │ │ │ ├── lowlevel_clwd.py
│ │ │ │ ├── lowlevel_gopro.py
│ │ │ │ ├── lowlevel_reds.py
│ │ │ │ └── lowlevel_sidd.py
│ │ │ ├── pose/
│ │ │ │ └── pose.py
│ │ │ ├── prompt/
│ │ │ │ ├── color_list_train_small.txt
│ │ │ │ ├── prompt_deblur.txt
│ │ │ │ ├── prompt_denoise.txt
│ │ │ │ ├── prompt_dewatermark.txt
│ │ │ │ ├── prompt_pose.txt
│ │ │ │ └── prompt_seg.txt
│ │ │ ├── seg/
│ │ │ │ ├── coco_stuff.py
│ │ │ │ ├── grefcoco.py
│ │ │ │ ├── grefcoco_segmentation.py
│ │ │ │ ├── refcoco.py
│ │ │ │ └── refcoco_segmentation.py
│ │ │ └── utils/
│ │ │ └── zip_manager.py
│ │ ├── edit_app.py
│ │ ├── edit_cli.py
│ │ ├── environment.yaml
│ │ ├── main.py
│ │ ├── scripts/
│ │ │ ├── convert_ckpt.py
│ │ │ ├── download_pretrained_sd.sh
│ │ │ ├── inference_example.sh
│ │ │ └── run_multinode.sh
│ │ ├── stable_diffusion/
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── Stable_Diffusion_v1_Model_Card.md
│ │ │ ├── assets/
│ │ │ │ ├── results.gif.REMOVED.git-id
│ │ │ │ ├── stable-samples/
│ │ │ │ │ ├── img2img/
│ │ │ │ │ │ ├── upscaling-in.png.REMOVED.git-id
│ │ │ │ │ │ └── upscaling-out.png.REMOVED.git-id
│ │ │ │ │ └── txt2img/
│ │ │ │ │ ├── merged-0005.png.REMOVED.git-id
│ │ │ │ │ ├── merged-0006.png.REMOVED.git-id
│ │ │ │ │ └── merged-0007.png.REMOVED.git-id
│ │ │ │ └── txt2img-preview.png.REMOVED.git-id
│ │ │ ├── configs/
│ │ │ │ ├── autoencoder/
│ │ │ │ │ ├── autoencoder_kl_16x16x16.yaml
│ │ │ │ │ ├── autoencoder_kl_32x32x4.yaml
│ │ │ │ │ ├── autoencoder_kl_64x64x3.yaml
│ │ │ │ │ └── autoencoder_kl_8x8x64.yaml
│ │ │ │ ├── latent-diffusion/
│ │ │ │ │ ├── celebahq-ldm-vq-4.yaml
│ │ │ │ │ ├── cin-ldm-vq-f8.yaml
│ │ │ │ │ ├── cin256-v2.yaml
│ │ │ │ │ ├── ffhq-ldm-vq-4.yaml
│ │ │ │ │ ├── lsun_bedrooms-ldm-vq-4.yaml
│ │ │ │ │ ├── lsun_churches-ldm-kl-8.yaml
│ │ │ │ │ └── txt2img-1p4B-eval.yaml
│ │ │ │ ├── retrieval-augmented-diffusion/
│ │ │ │ │ └── 768x768.yaml
│ │ │ │ └── stable-diffusion/
│ │ │ │ └── v1-inference.yaml
│ │ │ ├── environment.yaml
│ │ │ ├── ldm/
│ │ │ │ ├── lr_scheduler.py
│ │ │ │ ├── models/
│ │ │ │ │ ├── autoencoder.py
│ │ │ │ │ └── diffusion/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── classifier.py
│ │ │ │ │ ├── ddim.py
│ │ │ │ │ ├── ddpm.py
│ │ │ │ │ ├── ddpm_edit.py
│ │ │ │ │ ├── dpm_solver/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── dpm_solver.py
│ │ │ │ │ │ └── sampler.py
│ │ │ │ │ └── plms.py
│ │ │ │ ├── modules/
│ │ │ │ │ ├── attention.py
│ │ │ │ │ ├── diffusionmodules/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── model.py
│ │ │ │ │ │ ├── openaimodel.py
│ │ │ │ │ │ └── util.py
│ │ │ │ │ ├── distributions/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── distributions.py
│ │ │ │ │ ├── ema.py
│ │ │ │ │ ├── encoders/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── modules.py
│ │ │ │ │ ├── image_degradation/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── bsrgan.py
│ │ │ │ │ │ ├── bsrgan_light.py
│ │ │ │ │ │ └── utils_image.py
│ │ │ │ │ ├── losses/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── contperceptual.py
│ │ │ │ │ │ └── vqperceptual.py
│ │ │ │ │ └── x_transformer.py
│ │ │ │ └── util.py
│ │ │ ├── main.py
│ │ │ ├── models/
│ │ │ │ ├── first_stage_models/
│ │ │ │ │ ├── kl-f16/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── kl-f32/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── kl-f4/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── kl-f8/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── vq-f16/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── vq-f4/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── vq-f4-noattn/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ ├── vq-f8/
│ │ │ │ │ │ └── config.yaml
│ │ │ │ │ └── vq-f8-n256/
│ │ │ │ │ └── config.yaml
│ │ │ │ └── ldm/
│ │ │ │ ├── bsr_sr/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── celeba256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── cin256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── ffhq256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── inpainting_big/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── layout2img-openimages256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── lsun_beds256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── lsun_churches256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── semantic_synthesis256/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── semantic_synthesis512/
│ │ │ │ │ └── config.yaml
│ │ │ │ └── text2img256/
│ │ │ │ └── config.yaml
│ │ │ ├── notebook_helpers.py
│ │ │ ├── scripts/
│ │ │ │ ├── download_first_stages.sh
│ │ │ │ ├── download_models.sh
│ │ │ │ ├── img2img.py
│ │ │ │ ├── inpaint.py
│ │ │ │ ├── knn2img.py
│ │ │ │ ├── latent_imagenet_diffusion.ipynb.REMOVED.git-id
│ │ │ │ ├── sample_diffusion.py
│ │ │ │ ├── tests/
│ │ │ │ │ └── test_watermark.py
│ │ │ │ ├── train_searcher.py
│ │ │ │ └── txt2img.py
│ │ │ └── setup.py
│ │ └── utils/
│ │ ├── deepspeed.py
│ │ ├── logger.py
│ │ └── utils.py
│ ├── edict/
│ │ ├── edict_functions.py
│ │ └── my_diffusers/
│ │ ├── __init__.py
│ │ ├── commands/
│ │ │ ├── __init__.py
│ │ │ ├── diffusers_cli.py
│ │ │ └── env.py
│ │ ├── configuration_utils.py
│ │ ├── dependency_versions_check.py
│ │ ├── dependency_versions_table.py
│ │ ├── dynamic_modules_utils.py
│ │ ├── hub_utils.py
│ │ ├── modeling_utils.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── attention.py
│ │ │ ├── embeddings.py
│ │ │ ├── resnet.py
│ │ │ ├── unet_2d.py
│ │ │ ├── unet_2d_condition.py
│ │ │ ├── unet_blocks.py
│ │ │ └── vae.py
│ │ ├── onnx_utils.py
│ │ ├── optimization.py
│ │ ├── pipeline_utils.py
│ │ ├── pipelines/
│ │ │ ├── __init__.py
│ │ │ ├── ddim/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_ddim.py
│ │ │ ├── ddpm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_ddpm.py
│ │ │ ├── latent_diffusion/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_latent_diffusion.py
│ │ │ ├── latent_diffusion_uncond/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_latent_diffusion_uncond.py
│ │ │ ├── pndm/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_pndm.py
│ │ │ ├── score_sde_ve/
│ │ │ │ ├── __init__.py
│ │ │ │ └── pipeline_score_sde_ve.py
│ │ │ ├── stable_diffusion/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── pipeline_stable_diffusion.py
│ │ │ │ ├── pipeline_stable_diffusion_img2img.py
│ │ │ │ ├── pipeline_stable_diffusion_inpaint.py
│ │ │ │ ├── pipeline_stable_diffusion_onnx.py
│ │ │ │ └── safety_checker.py
│ │ │ └── stochastic_karras_ve/
│ │ │ ├── __init__.py
│ │ │ └── pipeline_stochastic_karras_ve.py
│ │ ├── schedulers/
│ │ │ ├── __init__.py
│ │ │ ├── scheduling_ddim.py
│ │ │ ├── scheduling_ddpm.py
│ │ │ ├── scheduling_karras_ve.py
│ │ │ ├── scheduling_lms_discrete.py
│ │ │ ├── scheduling_pndm.py
│ │ │ ├── scheduling_sde_ve.py
│ │ │ ├── scheduling_sde_vp.py
│ │ │ └── scheduling_utils.py
│ │ ├── testing_utils.py
│ │ ├── training_utils.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── dummy_scipy_objects.py
│ │ ├── dummy_transformers_and_inflect_and_unidecode_objects.py
│ │ ├── dummy_transformers_and_onnx_objects.py
│ │ ├── dummy_transformers_objects.py
│ │ ├── import_utils.py
│ │ ├── logging.py
│ │ ├── model_card_template.md
│ │ └── outputs.py
│ ├── edit_friendly_ddm/
│ │ ├── inversion_utils.py
│ │ ├── ptp_classes.py
│ │ ├── ptp_utils.py
│ │ └── seq_aligner.py
│ ├── instructpix2pix/
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── generate.yaml
│ │ │ └── train.yaml
│ │ ├── dataset_creation/
│ │ │ ├── generate_img_dataset.py
│ │ │ ├── generate_txt_dataset.py
│ │ │ ├── prepare_dataset.py
│ │ │ └── prepare_for_gpt.py
│ │ ├── edit_app.py
│ │ ├── edit_cli.py
│ │ ├── edit_dataset.py
│ │ ├── environment.yaml
│ │ ├── main.py
│ │ ├── metrics/
│ │ │ ├── clip_similarity.py
│ │ │ └── compute_metrics.py
│ │ ├── prompt_app.py
│ │ ├── scripts/
│ │ │ ├── download_checkpoints.sh
│ │ │ ├── download_data.sh
│ │ │ └── download_pretrained_sd.sh
│ │ └── stable_diffusion/
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── Stable_Diffusion_v1_Model_Card.md
│ │ ├── assets/
│ │ │ ├── results.gif.REMOVED.git-id
│ │ │ ├── stable-samples/
│ │ │ │ ├── img2img/
│ │ │ │ │ ├── upscaling-in.png.REMOVED.git-id
│ │ │ │ │ └── upscaling-out.png.REMOVED.git-id
│ │ │ │ └── txt2img/
│ │ │ │ ├── merged-0005.png.REMOVED.git-id
│ │ │ │ ├── merged-0006.png.REMOVED.git-id
│ │ │ │ └── merged-0007.png.REMOVED.git-id
│ │ │ └── txt2img-preview.png.REMOVED.git-id
│ │ ├── configs/
│ │ │ ├── autoencoder/
│ │ │ │ ├── autoencoder_kl_16x16x16.yaml
│ │ │ │ ├── autoencoder_kl_32x32x4.yaml
│ │ │ │ ├── autoencoder_kl_64x64x3.yaml
│ │ │ │ └── autoencoder_kl_8x8x64.yaml
│ │ │ ├── latent-diffusion/
│ │ │ │ ├── celebahq-ldm-vq-4.yaml
│ │ │ │ ├── cin-ldm-vq-f8.yaml
│ │ │ │ ├── cin256-v2.yaml
│ │ │ │ ├── ffhq-ldm-vq-4.yaml
│ │ │ │ ├── lsun_bedrooms-ldm-vq-4.yaml
│ │ │ │ ├── lsun_churches-ldm-kl-8.yaml
│ │ │ │ └── txt2img-1p4B-eval.yaml
│ │ │ ├── retrieval-augmented-diffusion/
│ │ │ │ └── 768x768.yaml
│ │ │ └── stable-diffusion/
│ │ │ └── v1-inference.yaml
│ │ ├── environment.yaml
│ │ ├── ldm/
│ │ │ ├── lr_scheduler.py
│ │ │ ├── models/
│ │ │ │ ├── autoencoder.py
│ │ │ │ └── diffusion/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── classifier.py
│ │ │ │ ├── ddim.py
│ │ │ │ ├── ddpm.py
│ │ │ │ ├── ddpm_edit.py
│ │ │ │ ├── dpm_solver/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── dpm_solver.py
│ │ │ │ │ └── sampler.py
│ │ │ │ └── plms.py
│ │ │ ├── modules/
│ │ │ │ ├── attention.py
│ │ │ │ ├── diffusionmodules/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── model.py
│ │ │ │ │ ├── openaimodel.py
│ │ │ │ │ └── util.py
│ │ │ │ ├── distributions/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── distributions.py
│ │ │ │ ├── ema.py
│ │ │ │ ├── encoders/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── modules.py
│ │ │ │ ├── image_degradation/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── bsrgan.py
│ │ │ │ │ ├── bsrgan_light.py
│ │ │ │ │ └── utils_image.py
│ │ │ │ ├── losses/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── contperceptual.py
│ │ │ │ │ └── vqperceptual.py
│ │ │ │ └── x_transformer.py
│ │ │ └── util.py
│ │ ├── main.py
│ │ ├── models/
│ │ │ ├── first_stage_models/
│ │ │ │ ├── kl-f16/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── kl-f32/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── kl-f4/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── kl-f8/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── vq-f16/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── vq-f4/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── vq-f4-noattn/
│ │ │ │ │ └── config.yaml
│ │ │ │ ├── vq-f8/
│ │ │ │ │ └── config.yaml
│ │ │ │ └── vq-f8-n256/
│ │ │ │ └── config.yaml
│ │ │ └── ldm/
│ │ │ ├── bsr_sr/
│ │ │ │ └── config.yaml
│ │ │ ├── celeba256/
│ │ │ │ └── config.yaml
│ │ │ ├── cin256/
│ │ │ │ └── config.yaml
│ │ │ ├── ffhq256/
│ │ │ │ └── config.yaml
│ │ │ ├── inpainting_big/
│ │ │ │ └── config.yaml
│ │ │ ├── layout2img-openimages256/
│ │ │ │ └── config.yaml
│ │ │ ├── lsun_beds256/
│ │ │ │ └── config.yaml
│ │ │ ├── lsun_churches256/
│ │ │ │ └── config.yaml
│ │ │ ├── semantic_synthesis256/
│ │ │ │ └── config.yaml
│ │ │ ├── semantic_synthesis512/
│ │ │ │ └── config.yaml
│ │ │ └── text2img256/
│ │ │ └── config.yaml
│ │ ├── notebook_helpers.py
│ │ ├── scripts/
│ │ │ ├── download_first_stages.sh
│ │ │ ├── download_models.sh
│ │ │ ├── img2img.py
│ │ │ ├── inpaint.py
│ │ │ ├── knn2img.py
│ │ │ ├── latent_imagenet_diffusion.ipynb.REMOVED.git-id
│ │ │ ├── sample_diffusion.py
│ │ │ ├── tests/
│ │ │ │ └── test_watermark.py
│ │ │ ├── train_searcher.py
│ │ │ └── txt2img.py
│ │ └── setup.py
│ ├── masactrl/
│ │ ├── diffuser_utils.py
│ │ ├── masactrl.py
│ │ └── masactrl_utils.py
│ ├── p2p/
│ │ ├── attention_control.py
│ │ ├── inversion.py
│ │ ├── p2p_guidance_forward.py
│ │ ├── proximal_guidance_forward.py
│ │ ├── scheduler_dev.py
│ │ └── seq_aligner.py
│ ├── p2p_editor.py
│ ├── pix2pix_zero/
│ │ ├── base_pipeline.py
│ │ ├── cross_attention.py
│ │ ├── ddim_inv.py
│ │ ├── edit_directions.py
│ │ ├── edit_pipeline.py
│ │ └── scheduler.py
│ └── stylediffusion/
│ ├── clip_util.py
│ ├── global_var.py
│ ├── inversion.py
│ ├── ptp_utils_v.py
│ ├── seq_aligner.py
│ └── utils.py
├── run_editing_blended_latent_diffusion.py
├── run_editing_edict.py
├── run_editing_edit_friendly_p2p.py
├── run_editing_instructdiffusion.py
├── run_editing_instructpix2pix.py
├── run_editing_masactrl.py
├── run_editing_p2p.py
├── run_editing_p2p_one_image.ipynb
├── run_editing_p2p_one_image.py
├── run_editing_pix2pix_zero.py
├── run_editing_pnp.py
├── run_editing_stylediffusion.py
└── utils/
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data
__pycache__
.vscode
output
*.csv
*.out
*.bash
================================================
FILE: README.md
================================================
# PnPInversion
This repository contains the implementation of the ICLR 2024 paper "PnP Inversion: Boosting Diffusion-based Editing with 3 Lines of Code".
Keywords: Diffusion Model, Image Inversion, Image Editing
> [Xuan Ju](https://github.com/juxuan27)<sup>12</sup>, [Ailing Zeng](https://ailingzeng.site/)<sup>2*</sup>, [Yuxuan Bian](https://github.com/TreastBean)<sup>1</sup>, [Shaoteng Liu](https://www.shaotengliu.com/)<sup>1</sup>, [Qiang Xu](https://cure-lab.github.io/)<sup>1*</sup><br>
> <sup>1</sup>The Chinese University of Hong Kong <sup>2</sup>International Digital Economy Academy <sup>*</sup>Corresponding Author
<p align="center">
<a href="https://cure-lab.github.io/PnPInversion/">Project Page</a> |
<a href="https://arxiv.org/abs/2310.01506">Arxiv</a> |
<a href="https://readpaper.com/paper/4807149696887816193">Readpaper</a> |
<a href="https://forms.gle/hVMkTABb4uvZVjme9">Benchmark</a> |
<a href="https://github.com/cure-lab/DirectInversion">Code</a> |
<a href="https://drive.google.com/file/d/1HGr4ETPa7w-08KKOMhfxhngzQ9Y9Nj4H/view">Video</a> |
</p>
**📖 Table of Contents**
- [Method Overview](#method-overview)
- [Getting Started](#getting-started)
- [Environment Requirement](#environment-requirement)
- [Benchmark Download](#benchmark-download)
- [Running Scripts](#running-scripts)
- [Inference](#inference)
- [Evaluation](#evaluation)
- [Quantitative Results](#quantitative-results)
- [Qualitative Results](#qualitative-results)
- [Cite Us](#cite-us)
- [Acknowledgement](#acknowledgement)
## 🛠️ Method Overview
<span id="method-overview"></span>
Text-guided diffusion models revolutionize image generation and editing, offering exceptional realism and diversity. Specifically, in the context of diffusion-based editing, common practice begins with a source image and a target prompt for editing. It involves obtaining a noisy latent vector corresponding to the source image using the diffusion model, which is then supplied to separate source and target diffusion branches for editing. The accuracy of this inversion process significantly impacts the final editing outcome, influencing both *essential content preservation* of the source image and *edit fidelity* according to the target prompt.
Previous inversion techniques attempted to find a unified solution in both the source and target diffusion branches. However, theoretical and empirical analysis shows that, in fact, a disentangling of the two branches leads to a clear separation of the responsibility for essential content preservation and edit fidelity, thus leading to better results in both aspects. In this paper, we introduce a novel technique called “**PnP Inversion**,” which rectifies inversion deviations directly within the source diffusion branch using just three lines of code, while leaving the target diffusion branch unaltered. To systematically evaluate image editing performance, we present **PIE-Bench**, an editing benchmark featuring 700 images with diverse scenes and editing types, complemented by versatile annotations. Our evaluation metrics, with a focus on editability and structure/background preservation, demonstrate the superior edit performance and inference speed of PnP Inversion across eight editing methods compared to five inversion techniques.
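For intuition, here is a minimal conceptual sketch of that idea (variable and function names are ours for illustration, not the repository's; the actual implementations live in the model-specific files under `models/`): at each denoising step, the source-branch latent is rectified back to the latent recorded at the same timestep during inversion, while the target branch is left untouched.
```python
import torch

def rectify_source_branch(source_latent: torch.Tensor,
                          inversion_latent: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch: snap the source-branch latent back to the latent
    recorded at the same timestep during inversion. The target branch is not modified."""
    offset = inversion_latent - source_latent  # deviation accumulated by the source branch
    source_latent = source_latent + offset     # rectification: the source branch now follows the inversion trajectory
    return source_latent
```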


## 🚀 Getting Started
<span id="getting-started"></span>
### Environment Requirement 🌍
<span id="environment-requirement"></span>
This is important!!! Since different models have different Python environment requirements (e.g., diffusers version), we list the environments in the folder "environment", detailed as follows:
- p2p_requirements.txt: for models in `run_editing_p2p.py`, `run_editing_blended_latent_diffusion.py`, `run_editing_stylediffusion.py`, and `run_editing_edit_friendly_p2p.py`
- instructdiffusion_requirements.txt: for models in `run_editing_instructdiffusion.py` and `run_editing_instructpix2pix.py`
- masactrl_requirements.txt: for models in `run_editing_masactrl.py`
- pnp_requirements.txt: for models in `run_editing_pnp.py`
- pix2pix_zero_requirements.txt: for models in `run_editing_pix2pix_zero.py`
- edict_requirements.txt: for models in `run_editing_edict.py`
For example, if you want to use the models in `run_editing_p2p.py`, you need to install the environment as follows:
```shell
conda create -n p2p python=3.9 -y
conda activate p2p
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install -r environment/p2p_requirements.txt
```
### Benchmark Download ⬇️
<span id="benchmark-download"></span>
You can download the benchmark PIE-Bench (Prompt-driven Image Editing Benchmark) [here](https://forms.gle/hVMkTABb4uvZVjme9). The data structure should look like this:
```python
|-- data
|-- annotation_images
|-- 0_random_140
|-- 000000000000.jpg
|-- 000000000001.jpg
|-- ...
|-- 1_change_object_80
|-- 1_artificial
|-- 1_animal
|-- 111000000000.jpg
|-- 111000000001.jpg
|-- ...
|-- 2_human
|-- 3_indoor
|-- 4_outdoor
|-- 2_natural
|-- ...
|-- ...
|-- mapping_file_ti2i_benchmark.json # the mapping file of TI2I benchmark, contains editing text
|-- mapping_file.json # the mapping file of PIE-Bench, contains editing text, blended word, and mask annotation
```
**PIE-Bench Benchmark:**
<details> <summary>Containing 700 images with 10 types of editing; the folder name in "annotation_images" indicates the editing type. [Unfold for details] </summary>
| Folder Name | Editing Type | Explanation |
| :-----: | :----: | :----: |
| 0_random_140 | 0. random editing | random prompt written by volunteers or examples in previous research. 140 images in total. |
| 1_change_object_80 | 1. change object | change an object to another, e.g., dog to cat. 80 images in total. |
| 2_add_object_80 | 2. add object | add an object, e.g., add flowers. 80 images in total. |
| 3_delete_object_80 | 3. delete object | delete an object, e.g., delete the clouds in the image. 80 images in total. |
| 4_change_attribute_content_40 | 4. change sth's content | change the content of sth, e.g., change a smiling man to an angry man by editing his facial expression. 40 images in total. |
| 5_change_attribute_pose_40 | 5. change sth's pose | change the pose of sth, e.g., change a standing dog to a running dog. 40 images in total. |
| 6_change_attribute_color_40 | 6. change sth's color | change the color of sth, e.g., change a red heart to a pink heart. 40 images in total. |
| 7_change_attribute_material_40 | 7. change sth's material | change the material of sth, e.g., change a wooden table to a glass table. 40 images in total. |
| 8_change_background_80 | 8. change image background | change the image background, e.g., change white background to grasses. 80 images in total. |
| 9_change_style_80 | 9. change image style | change the image style, e.g., change a photo to watercolor. 80 images in total. |
In editing types 1-9, the images are equally distributed between artificial and natural images *(note that both categories are real images: artificial images are paintings or other human-created images, while natural images are photos)*. Within each category, images are equally distributed across animal, human, indoor scene, and outdoor scene.
</details>
<details> <summary> The "mapping_file_ti2i_benchmark.json" contains annotation of editing text, blended word, and mask annotation for PIE-Bench. [Unfold for details] </summary>
The mapping_file_ti2i_benchmark.json contains a dict with following structure:
```python
{
"000000000000": {
"image_path": "0_random_140/000000000000.jpg", # image path
"original_prompt": "a slanted mountain bicycle on the road in front of a building", # image prompt of the original image, [] shows the difference with editing_prompt
"editing_prompt": "a slanted [rusty] mountain bicycle on the road in front of a building", # image prompt of the edited image, [] shows the difference with original_prompt
"editing_instruction": "Make the frame of the bike rusty", # image editing instruction
"editing_type_id": "0", # image editing type
"blended_word": "bicycle bicycle", # the word to be edited
"mask": [...] # mask with RLE encode, the part that needed to be edited is 1, otherwise 0.
},
...
}
```
</details>
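As a quick sanity check, the snippet below is a minimal sketch of loading one annotation and reconstructing its editing mask. The path and keys follow the structure above, and the run-length decoding mirrors `mask_decode` in `evaluation/evaluate.py` (without its boundary padding).
```python
import json
import numpy as np

def decode_mask(encoded_mask, image_shape=(512, 512)):
    """Decode [start, length, start, length, ...] run-length pairs into a binary H x W mask."""
    flat = np.zeros(image_shape[0] * image_shape[1])
    for i in range(0, len(encoded_mask), 2):
        start, length = encoded_mask[i], encoded_mask[i + 1]
        flat[start:start + length] = 1  # numpy slicing clips runs that exceed the image size
    return flat.reshape(image_shape)

with open("data/mapping_file.json") as f:
    mapping = json.load(f)

item = mapping["000000000000"]
mask = decode_mask(item["mask"])  # 1 inside the region to be edited, 0 elsewhere
print(item["image_path"], item["editing_prompt"], int(mask.sum()))
```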
**TI2I Benchmark:**
We also include the [TI2I benchmark](https://pnp-diffusion.github.io/) in the data for ease of use. The TI2I benchmark contains 55 images and an edited image prompt for each image.
The images are provided in data/annotation_images/ti2i_benchmark and the mapping file is provided in data/mapping_file_ti2i_benchmark.json.
## 🏃🏼 Running Scripts
<span id="running-scripts"></span>
### Inference 📜
<span id="inference"></span>
**Run the Benchmark**
You can reproduce all image editing results through `run_editing_p2p.py`, `run_editing_edit_friendly_p2p.py`, `run_editing_masactrl.py`, `run_editing_pnp.py`, `run_editing_edict.py`, `run_editing_pix2pix_zero.py`, `run_editing_instructdiffusion.py`, `run_editing_blended_latent_diffusion.py`, `run_editing_stylediffusion.py`, and `run_editing_instructpix2pix.py`. These Python files contain the methods listed below (please unfold):
<details> <summary> run_editing_p2p.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| DDIM | Prompt-to-Prompt | ddim+p2p | |
| Null-text Inversion | Prompt-to-Prompt | null-text-inversion+p2p | |
| Negative-prompt Inversion | Prompt-to-Prompt | negative-prompt-inversion+p2p | |
| DirectInversion(Ours) | Prompt-to-Prompt | directinversion+p2p | |
| DirectInversion(Ours) (ablation: with various guidance scale) | Prompt-to-Prompt (ablation: with various guidance scale) | directinversion+p2p_guidance_{i}_{f} | For ablation study. {i} means the inversion guidance scale, {f} means the forward guidance scale. {i} can be chosen from \[0,1,25,5,75\] and {f} from \[1,25,5,75\], where 25 and 75 denote 2.5 and 7.5. For example, directinversion+p2p_guidance_1_75 means inversion with guidance scale 1.0 and forward with 7.5. |
| Null-text Inversion | Proximal Guidance | null-text-inversion+proximal-guidance | |
| Negative-prompt Inversion | Proximal Guidance | negative-prompt-inversion+proximal-guidance | |
| Null-latent Inversion | Prompt-to-Prompt | ablation_null-latent-inversion+p2p | For ablation study. Edit the Null-text Inversion to Null-latent Inversion. |
| Null-Text Inversion (ablation: single branch) | Prompt-to-Prompt | ablation_null-text-inversion_single_branch+p2p | For ablation study. Edit the Null-text Inversion to exchange null embedding only in source branch. |
| DirectInversion(Ours) (ablation: add with scale) | Prompt-to-Prompt (ablation: add with scale) | ablation_directinversion_{s}+p2p | For ablation study. {s} means the added scale. {s} could be chosen from \[04,08\]. For example, ablation_directinversion_04+p2p means add with scale=0.4. |
| DirectInversion(Ours) (ablation: skip step) | Prompt-to-Prompt (ablation: skip step) | ablation_directinversion_interval_{s}+p2p | For ablation study. {s} means the skip step. {s} could be chosen from \[2,5,10,24,49\]. For example, ablation_directinversion_interval_2+p2p means skip every 2 steps. |
| DirectInversion(Ours) (ablation: add source offset for target latent) | Prompt-to-Prompt (ablation: add source offset for target latent) | ablation_directinversion_add-source+p2p | |
| DirectInversion(Ours) (ablation: add target offset for target latent) | Prompt-to-Prompt (ablation: add target offset for target latent) | ablation_directinversion_add-target+p2p | |
</details>
<details> <summary> run_editing_stylediffusion.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| StyleDiffusion | Prompt-to-Prompt | stylediffusion+p2p | |
</details>
<details> <summary> run_editing_edit_friendly_p2p.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| Edit Friendly Inversion | Prompt-to-Prompt | edit-friendly-inversion+p2p | |
</details>
<details> <summary> run_editing_masactrl.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| DDIM | MasaCtrl | ddim+masactrl | |
| DirectInversion(Ours) | MasaCtrl | directinversion+masactrl | |
</details>
<details> <summary> run_editing_pnp.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| DDIM | Plug-and-Play | ddim+pnp | |
| DirectInversion(Ours) | Plug-and-Play | directinversion+pnp | |
</details>
<details> <summary> run_editing_pix2pix_zero.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| DDIM | Pix2Pix-Zero | ddim+pix2pix-zero | |
| DirectInversion(Ours) | Pix2Pix-Zero | directinversion+pix2pix-zero | |
</details>
<details> <summary> run_editing_edict.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| EDICT | | edict+direct_forward | |
</details>
<details> <summary> run_editing_instructdiffusion.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| | InstructDiffusion | instruct-diffusion | |
</details>
<details> <summary> run_editing_instructpix2pix.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| | Instruct-Pix2Pix | instruct-pix2pix | |
</details>
<details> <summary> run_editing_blended_latent_diffusion.py </summary>
| Inversion Method | Editing Method | Index | Explanation
| :-----: | :----: | :----: | :----: |
| | Blended Latent Diffusion | blended-latent-diffusion | |
</details>
For example, if you want to run DirectInversion(Ours) + Prompt-to-Prompt, you can see that this method has the index `directinversion+p2p` in `run_editing_p2p.py`. You can then run editing type 0 with DirectInversion(Ours) + Prompt-to-Prompt through:
```
python run_editing_p2p.py --output_path output --edit_category_list 0 --edit_method_list directinversion+p2p
```
You can also run multiple editing methods and multiple editing types with:
```
python run_editing_p2p.py --edit_category_list 0 1 2 3 4 5 6 7 8 9 --edit_method_list directinversion+p2p null-text+p2p
```
You can also specify --rerun_exist_images to choose whether to rerun existing images, and --data_path and --output_path to set the image path and output path.
**Run Any Image**
You can convert your own images and editing prompts into the same format as our benchmark to run a large number of images. You can also adapt the given Python files to your own image. We provide an adapted version of `run_editing_p2p.py` as `run_editing_p2p_one_image.py`. You can edit a single image through:
```shell
python -u run_editing_p2p_one_image.py --image_path scripts/example_cake.jpg --original_prompt "a round cake with orange frosting on a wooden plate" --editing_prompt "a square cake with orange frosting on a wooden plate" --blended_word "cake cake" --output_path "directinversion+p2p.jpg" "ddim+p2p.jpg" --edit_method_list "directinversion+p2p" "ddim+p2p"
```
We also provide jupyter notebook demo `run_editing_p2p_one_image.ipynb`.
Note that we use default parameters in our code. However, they are not optimal for all images. You may adjust them based on your inputs.
### Evaluation 📐
<span id="evaluation"></span>
You can run evaluation through:
```
python evaluation/evaluate.py --metrics "structure_distance" "psnr_unedit_part" "lpips_unedit_part" "mse_unedit_part" "ssim_unedit_part" "clip_similarity_source_image" "clip_similarity_target_image" "clip_similarity_target_image_edit_part" --result_path evaluation_result.csv --edit_category_list 0 1 2 3 4 5 6 7 8 9 --tgt_methods 1_ddim+p2p 1_directinversion+p2p
```
You can find the available choices for tgt_methods in the dict `all_tgt_image_folders` in `evaluation/evaluate.py`.
All editing results are available for download [here](https://drive.google.com/drive/folders/1hy8QTiaOZllKmwn6-vwWmHOpRP3uX-Ji?usp=sharing). You can download them and arrange them in the following file structure to reproduce all the results in our paper.
```
output
|-- ddim+p2p
|-- annotation_images
|-- ...
|-- directinversion+p2p
|-- annotation_images
|-- ...
...
```
If you want to evaluate the whole table's results shown in our paper, you can run:
```
python evaluation/evaluate.py --metrics "structure_distance" "psnr_unedit_part" "lpips_unedit_part" "mse_unedit_part" "ssim_unedit_part" "clip_similarity_source_image" "clip_similarity_target_image" "clip_similarity_target_image_edit_part" --result_path evaluation_result.csv --edit_category_list 0 1 2 3 4 5 6 7 8 9 --tgt_methods 1 --evaluate_whole_table
```
Then, all results in Table 1 will be written to evaluation_result.csv.
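To aggregate the per-image rows afterwards, here is a minimal sketch (assuming pandas is installed; column names follow the `<method>|<metric>` pattern written by `evaluation/evaluate.py`):
```python
import pandas as pd

df = pd.read_csv("evaluation_result.csv")
# keep only the score columns; non-numeric entries such as "nan" become NaN and are ignored
scores = df.drop(columns=["file_id"]).apply(pd.to_numeric, errors="coerce")
print(scores.mean().round(4))  # one averaged score per method|metric column
```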
## 🥇 Quantitative Results
<span id="quantitative-results"></span>
Comparison of PnP Inversion with other inversion techniques across various editing methods:

More results can be found in the main paper.
## 🌟 Qualitative Results
<span id="qualitative-results"></span>
Performance enhancement of incorporating PnP Inversion into four diffusion-based
editing methods:

Visualization results of different inversion and editing techniques:

More results can be found in the main paper.
## 🤝🏼 Cite Us
<span id="cite-us"></span>
```
@article{ju2023direct,
title={PnP Inversion: Boosting Diffusion-based Editing with 3 Lines of Code},
author={Ju, Xuan and Zeng, Ailing and Bian, Yuxuan and Liu, Shaoteng and Xu, Qiang},
journal={International Conference on Learning Representations ({ICLR})},
year={2024}
}
```
## 💖 Acknowledgement
<span id="acknowledgement"></span>
Our code is built upon [prompt-to-prompt](https://github.com/google/prompt-to-prompt), [StyleDiffusion](https://github.com/sen-mao/StyleDiffusion), [MasaCtrl](https://github.com/TencentARC/MasaCtrl), [pix2pix-zero](https://github.com/pix2pixzero/pix2pix-zero), [Plug-and-Play](https://github.com/MichalGeyer/plug-and-play), [Edit Friendly DDPM Noise Space](https://github.com/inbarhub/DDPM_inversion), [Blended Latent Diffusion](https://github.com/omriav/blended-latent-diffusion), [Proximal Guidance](https://github.com/phymhan/prompt-to-prompt), and [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix). Thanks to all the contributors!
================================================
FILE: environment/edict_requirements.txt
================================================
diffusers==0.6.0
transformers==4.19.2
matplotlib
omegaconf
imageio
================================================
FILE: environment/instructdiffusion_requirements.txt
================================================
einops==0.6.1
taming-transformers-rom1504==0.0.6
omegaconf==2.3.0
k-diffusion==0.0.16
deepspeed==0.10.2
timm==0.9.7
transformers==4.33.1
matplotlib
================================================
FILE: environment/masactrl_requirements.txt
================================================
diffusers==0.15.0
transformers
opencv-python
einops
omegaconf
pytorch_lightning
matplotlib
================================================
FILE: environment/p2p_requirements.txt
================================================
diffusers==0.10.0
transformers
ftfy
opencv-python
ipywidgets
matplotlib
accelerate
================================================
FILE: environment/pix2pix_zero_requirements.txt
================================================
diffusers==0.14.0
matplotlib
salesforce-lavis
================================================
FILE: environment/pnp_requirements.txt
================================================
diffusers==0.17.1
xformers==0.0.20
transformers==4.30.2
accelerate==0.20.3
matplotlib
salesforce-lavis
================================================
FILE: evaluation/evaluate.py
================================================
import json
import argparse
import os
import numpy as np
from PIL import Image
import csv
from evaluation.matrics_calculator import MetricsCalculator
def mask_decode(encoded_mask,image_shape=[512,512]):
length=image_shape[0]*image_shape[1]
mask_array=np.zeros((length,))
for i in range(0,len(encoded_mask),2):
splice_len=min(encoded_mask[i+1],length-encoded_mask[i])
for j in range(splice_len):
mask_array[encoded_mask[i]+j]=1
mask_array=mask_array.reshape(image_shape[0], image_shape[1])
    # set boundary pixels to 1 to avoid annotation errors at the image border
mask_array[0,:]=1
mask_array[-1,:]=1
mask_array[:,0]=1
mask_array[:,-1]=1
return mask_array
def calculate_metric(metrics_calculator,metric, src_image, tgt_image, src_mask, tgt_mask,src_prompt,tgt_prompt):
if metric=="psnr":
return metrics_calculator.calculate_psnr(src_image, tgt_image, None, None)
if metric=="lpips":
return metrics_calculator.calculate_lpips(src_image, tgt_image, None, None)
if metric=="mse":
return metrics_calculator.calculate_mse(src_image, tgt_image, None, None)
if metric=="ssim":
return metrics_calculator.calculate_ssim(src_image, tgt_image, None, None)
if metric=="structure_distance":
return metrics_calculator.calculate_structure_distance(src_image, tgt_image, None, None)
if metric=="psnr_unedit_part":
if (1-src_mask).sum()==0 or (1-tgt_mask).sum()==0:
return "nan"
else:
return metrics_calculator.calculate_psnr(src_image, tgt_image, 1-src_mask, 1-tgt_mask)
if metric=="lpips_unedit_part":
if (1-src_mask).sum()==0 or (1-tgt_mask).sum()==0:
return "nan"
else:
return metrics_calculator.calculate_lpips(src_image, tgt_image, 1-src_mask, 1-tgt_mask)
if metric=="mse_unedit_part":
if (1-src_mask).sum()==0 or (1-tgt_mask).sum()==0:
return "nan"
else:
return metrics_calculator.calculate_mse(src_image, tgt_image, 1-src_mask, 1-tgt_mask)
if metric=="ssim_unedit_part":
if (1-src_mask).sum()==0 or (1-tgt_mask).sum()==0:
return "nan"
else:
return metrics_calculator.calculate_ssim(src_image, tgt_image, 1-src_mask, 1-tgt_mask)
if metric=="structure_distance_unedit_part":
if (1-src_mask).sum()==0 or (1-tgt_mask).sum()==0:
return "nan"
else:
return metrics_calculator.calculate_structure_distance(src_image, tgt_image, 1-src_mask, 1-tgt_mask)
if metric=="psnr_edit_part":
if src_mask.sum()==0 or tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_psnr(src_image, tgt_image, src_mask, tgt_mask)
if metric=="lpips_edit_part":
if src_mask.sum()==0 or tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_lpips(src_image, tgt_image, src_mask, tgt_mask)
if metric=="mse_edit_part":
if src_mask.sum()==0 or tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_mse(src_image, tgt_image, src_mask, tgt_mask)
if metric=="ssim_edit_part":
if src_mask.sum()==0 or tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_ssim(src_image, tgt_image, src_mask, tgt_mask)
if metric=="structure_distance_edit_part":
if src_mask.sum()==0 or tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_structure_distance(src_image, tgt_image, src_mask, tgt_mask)
if metric=="clip_similarity_source_image":
return metrics_calculator.calculate_clip_similarity(src_image, src_prompt,None)
if metric=="clip_similarity_target_image":
return metrics_calculator.calculate_clip_similarity(tgt_image, tgt_prompt,None)
if metric=="clip_similarity_target_image_edit_part":
if tgt_mask.sum()==0:
return "nan"
else:
return metrics_calculator.calculate_clip_similarity(tgt_image, tgt_prompt,tgt_mask)
all_tgt_image_folders={
# results of comparing inversion
# ---
"1_ddim+p2p":"output/ddim+p2p/annotation_images",
"1_null-text-inversion+p2p_a800":"output/null-text-inversion+p2p_a800/annotation_images",
"1_null-text-inversion+p2p_3090":"output/null-text-inversion+p2p_3090/annotation_images",
"1_negative-prompt-inversion+p2p":"output/negative-prompt-inversion+p2p/annotation_images",
"1_stylediffusion+p2p":"output/stylediffusion+p2p/annotation_images",
"1_directinversion+p2p":"output/directinversion+p2p/annotation_images",
# ---
"1_ddim+masactrl":"output/ddim+masactrl/annotation_images",
"1_directinversion+masactrl":"output/directinversion+masactrl/annotation_images",
# ---
"1_ddim+pix2pix-zero":"output/ddim+pix2pix-zero/annotation_images",
"1_directinversion+pix2pix-zero":"output/directinversion+pix2pix-zero/annotation_images",
# ---
"1_ddim+pnp":"output/ddim+pnp/annotation_images",
"1_directinversion+pnp":"output/directinversion+pnp/annotation_images",
# ---
# results of comparing model-based methods
"2_instruct-pix2pix":"output/instruct-pix2pix/annotation_images",
"2_instruct-diffusion":"output/instruct-diffusion/annotation_images",
"2_blended-latent-diffusion":"output/blended-latent-diffusion/annotation_images",
"2_directinversion+p2p":"output/directinversion+p2p/annotation_images",
# results of different inversion/forward guidance scale
"3_directinversion+p2p_guidance_0_1":"output/directinversion+p2p_guidance_0_1/annotation_images",
"3_directinversion+p2p_guidance_0_5":"output/directinversion+p2p_guidance_0_5/annotation_images",
"3_directinversion+p2p_guidance_0_25":"output/directinversion+p2p_guidance_0_25/annotation_images",
"3_directinversion+p2p_guidance_0_75":"output/directinversion+p2p_guidance_0_75/annotation_images",
"3_directinversion+p2p_guidance_1_1":"output/directinversion+p2p_guidance_1_1/annotation_images",
"3_directinversion+p2p_guidance_1_5":"output/directinversion+p2p_guidance_1_5/annotation_images",
"3_directinversion+p2p_guidance_1_25":"output/directinversion+p2p_guidance_1_25/annotation_images",
"3_directinversion+p2p_guidance_1_75":"output/directinversion+p2p_guidance_1_75/annotation_images",
"3_directinversion+p2p_guidance_25_1":"output/directinversion+p2p_guidance_25_1/annotation_images",
"3_directinversion+p2p_guidance_25_5":"output/directinversion+p2p_guidance_25_5/annotation_images",
"3_directinversion+p2p_guidance_25_25":"output/directinversion+p2p_guidance_25_25/annotation_images",
"3_directinversion+p2p_guidance_25_75":"output/directinversion+p2p_guidance_25_75/annotation_images",
"3_directinversion+p2p_guidance_5_1":"output/directinversion+p2p_guidance_5_1/annotation_images",
"3_directinversion+p2p_guidance_5_5":"output/directinversion+p2p_guidance_5_5/annotation_images",
"3_directinversion+p2p_guidance_5_25":"output/directinversion+p2p_guidance_5_25/annotation_images",
"3_directinversion+p2p_guidance_5_75":"output/directinversion+p2p_guidance_5_75/annotation_images",
"3_directinversion+p2p_guidance_75_1":"output/directinversion+p2p_guidance_75_1/annotation_images",
"3_directinversion+p2p_guidance_75_5":"output/directinversion+p2p_guidance_75_5/annotation_images",
"3_directinversion+p2p_guidance_75_25":"output/directinversion+p2p_guidance_75_25/annotation_images",
"3_directinversion+p2p_guidance_75_75":"output/directinversion+p2p_guidance_75_75/annotation_images",
# results of background preservation method
"4_null-text-inverse+p2p_a800":"output/null-text-inversion+p2p_a800/annotation_images",
"4_null-text-inverse+p2p_3090":"output/null-text-inversion+p2p_3090/annotation_images",
"4_null-text-inversion+proximal-guidance":"output/null-text-inversion+proximal-guidance/annotation_images",
"4_negative-prompt-inversion+proximal-guidance":"output/negative-prompt-inversion+proximal-guidance/annotation_images",
"4_edit-friendly-inversion+p2p":"output/edit-friendly-inversion+p2p/annotation_images",
"4_edict+direct_forward":"output/edict+direct_forward/annotation_images",
"4_edict+p2p":"output/edict+p2p/annotation_images",
"4_directinversion+p2p":"output/directinversion+p2p/annotation_images",
# ablation results of contrast null-text-inversion with directinversion
"5_ablation_directinversion_04+p2p":"output/ablation_directinversion_04+p2p/annotation_images",
"5_ablation_directinversion_08+p2p":"output/ablation_directinversion_08+p2p/annotation_images",
"5_ablation_null-latent-inversion+p2p_a800":"output/ablation_null-latent-inversion+p2p_a800/annotation_images",
"5_ablation_null-latent-inversion+p2p_3090":"output/ablation_null-latent-inversion+p2p_3090/annotation_images",
"5_ablation_null-text-inversion_single_branch+p2p_a800":"output/ablation_null-text-inversion_single_branch+p2p_a800/annotation_images",
"5_ablation_null-text-inversion_single_branch+p2p_3090":"output/ablation_null-text-inversion_single_branch+p2p_3090/annotation_images",
# ablation results of different intervals
"6_ablation_directinversion_interval_2":"output/ablation_directinversion_interval_2+p2p/annotation_images",
"6_ablation_directinversion_interval_5":"output/ablation_directinversion_interval_5+p2p/annotation_images",
"6_ablation_directinversion_interval_10":"output/ablation_directinversion_interval_10+p2p/annotation_images",
"6_ablation_directinversion_interval_24":"output/ablation_directinversion_interval_24+p2p/annotation_images",
"6_ablation_directinversion_interval_49":"output/ablation_directinversion_interval_49+p2p/annotation_images",
# ablation results of different steps
"7_ablation_directinversion_step_20":"output/ablation_directinversion_step_20+p2p/annotation_images",
"7_ablation_directinversion_step_100":"output/ablation_directinversion_step_100+p2p/annotation_images",
"7_ablation_directinversion_step_500":"output/ablation_directinversion_step_500+p2p/annotation_images",
# ablation results of add target latent
"8_ablation_directinversion_add-source+p2p":"output/ablation_directinversion_add-source+p2p/annotation_images",
"8_ablation_directinversion_add-target+p2p":"output/ablation_directinversion_add-target+p2p/annotation_images",
}
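# Note: the leading digit of each key above groups methods by the result table they belong to;
# when --evaluate_whole_table is set, methods are selected by matching this digit (key[0]) below.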
if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--annotation_mapping_file', type=str, default="data/mapping_file.json")
parser.add_argument('--metrics', nargs = '+', type=str, default=[
"structure_distance",
"psnr_unedit_part",
"lpips_unedit_part",
"mse_unedit_part",
"ssim_unedit_part",
"clip_similarity_source_image",
"clip_similarity_target_image",
"clip_similarity_target_image_edit_part",
])
parser.add_argument('--src_image_folder', type=str, default="data/annotation_images")
parser.add_argument('--tgt_methods', nargs = '+', type=str, default=[
"1_ddim+p2p", "1_null-text-inversion+p2p_a800",
"1_null-text-inversion+p2p_3090", "1_negative-prompt-inversion+p2p",
"1_stylediffusion+p2p", "1_directinversion+p2p",
])
parser.add_argument('--result_path', type=str, default="evaluation_result.csv")
parser.add_argument('--device', type=str, default="cuda")
parser.add_argument('--edit_category_list', nargs = '+', type=str, default=[
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9"
    ]) # the editing categories that need to be run
    parser.add_argument('--evaluate_whole_table', action= "store_true") # evaluate all methods in the table groups selected by --tgt_methods
args = parser.parse_args()
annotation_mapping_file=args.annotation_mapping_file
metrics=args.metrics
src_image_folder=args.src_image_folder
tgt_methods=args.tgt_methods
edit_category_list=args.edit_category_list
evaluate_whole_table=args.evaluate_whole_table
tgt_image_folders={}
if evaluate_whole_table:
for key in all_tgt_image_folders:
if key[0] in tgt_methods:
tgt_image_folders[key]=all_tgt_image_folders[key]
else:
for key in tgt_methods:
tgt_image_folders[key]=all_tgt_image_folders[key]
result_path=args.result_path
metrics_calculator=MetricsCalculator(args.device)
with open(result_path,'w',newline="") as f:
csv_write = csv.writer(f)
csv_head=[]
for tgt_image_folder_key,_ in tgt_image_folders.items():
for metric in metrics:
csv_head.append(f"{tgt_image_folder_key}|{metric}")
data_row = ["file_id"]+csv_head
csv_write.writerow(data_row)
with open(annotation_mapping_file,"r") as f:
annotation_file=json.load(f)
for key, item in annotation_file.items():
if item["editing_type_id"] not in edit_category_list:
continue
print(f"evaluating image {key} ...")
base_image_path=item["image_path"]
mask=mask_decode(item["mask"])
original_prompt = item["original_prompt"].replace("[", "").replace("]", "")
editing_prompt = item["editing_prompt"].replace("[", "").replace("]", "")
mask=mask[:,:,np.newaxis].repeat([3],axis=2)
src_image_path=os.path.join(src_image_folder, base_image_path)
src_image = Image.open(src_image_path)
evaluation_result=[key]
for tgt_image_folder_key,tgt_image_folder in tgt_image_folders.items():
tgt_image_path=os.path.join(tgt_image_folder, base_image_path)
print(f"evluating method: {tgt_image_folder_key}")
tgt_image = Image.open(tgt_image_path)
if tgt_image.size[0] != tgt_image.size[1]:
# to evaluate editing
tgt_image = tgt_image.crop((tgt_image.size[0]-512,tgt_image.size[1]-512,tgt_image.size[0],tgt_image.size[1]))
# to evaluate reconstruction
# tgt_image = tgt_image.crop((tgt_image.size[0]-512*2,tgt_image.size[1]-512,tgt_image.size[0]-512,tgt_image.size[1]))
for metric in metrics:
print(f"evluating metric: {metric}")
evaluation_result.append(calculate_metric(metrics_calculator,metric,src_image, tgt_image, mask, mask, original_prompt, editing_prompt))
with open(result_path,'a+',newline="") as f:
csv_write = csv.writer(f)
csv_write.writerow(evaluation_result)
================================================
FILE: evaluation/matrics_calculator.py
================================================
import torch
from torchvision.transforms import Resize
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from torchmetrics.multimodal import CLIPScore
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.regression import MeanSquaredError
class VitExtractor:
BLOCK_KEY = 'block'
ATTN_KEY = 'attn'
PATCH_IMD_KEY = 'patch_imd'
QKV_KEY = 'qkv'
KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]
def __init__(self, model_name, device):
self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
self.model.eval()
self.model_name = model_name
self.hook_handlers = []
self.layers_dict = {}
self.outputs_dict = {}
for key in VitExtractor.KEY_LIST:
self.layers_dict[key] = []
self.outputs_dict[key] = []
self._init_hooks_data()
self.device=device
def _init_hooks_data(self):
self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
for key in VitExtractor.KEY_LIST:
# self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
self.outputs_dict[key] = []
def _register_hooks(self, **kwargs):
for block_idx, block in enumerate(self.model.blocks):
if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))
def _clear_hooks(self):
for handler in self.hook_handlers:
handler.remove()
self.hook_handlers = []
def _get_block_hook(self):
def _get_block_output(model, input, output):
self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
return _get_block_output
def _get_attn_hook(self):
def _get_attn_output(model, inp, output):
self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
return _get_attn_output
def _get_qkv_hook(self):
def _get_qkv_output(model, inp, output):
self.outputs_dict[VitExtractor.QKV_KEY].append(output)
return _get_qkv_output
# TODO: CHECK ATTN OUTPUT TUPLE
def _get_patch_imd_hook(self):
def _get_attn_output(model, inp, output):
self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
return _get_attn_output
def get_feature_from_input(self, input_img): # List([B, N, D])
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_qkv_feature_from_input(self, input_img):
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.QKV_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_attn_feature_from_input(self, input_img):
self._register_hooks()
self.model(input_img)
feature = self.outputs_dict[VitExtractor.ATTN_KEY]
self._clear_hooks()
self._init_hooks_data()
return feature
def get_patch_size(self):
return 8 if "8" in self.model_name else 16
def get_width_patch_num(self, input_img_shape):
b, c, h, w = input_img_shape
patch_size = self.get_patch_size()
return w // patch_size
def get_height_patch_num(self, input_img_shape):
b, c, h, w = input_img_shape
patch_size = self.get_patch_size()
return h // patch_size
def get_patch_num(self, input_img_shape):
patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
return patch_num
def get_head_num(self):
if "dino" in self.model_name:
return 6 if "s" in self.model_name else 12
return 6 if "small" in self.model_name else 12
def get_embedding_dim(self):
if "dino" in self.model_name:
return 384 if "s" in self.model_name else 768
return 384 if "small" in self.model_name else 768
def get_queries_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
return q
def get_keys_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
return k
def get_values_from_qkv(self, qkv, input_img_shape):
patch_num = self.get_patch_num(input_img_shape)
head_num = self.get_head_num()
embedding_dim = self.get_embedding_dim()
v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
return v
def get_keys_from_input(self, input_img, layer_num):
qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
return keys
def get_keys_self_sim_from_input(self, input_img, layer_num):
keys = self.get_keys_from_input(input_img, layer_num=layer_num)
h, t, d = keys.shape
concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
ssim_map = self.attn_cosine_sim(concatenated_keys[None, None, ...])
return ssim_map
def attn_cosine_sim(self,x, eps=1e-08):
x = x[0] # TEMP: getting rid of redundant dimension, TBF
norm1 = x.norm(dim=2, keepdim=True)
factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
sim_matrix = (x @ x.permute(0, 2, 1)) / factor
return sim_matrix
class LossG(torch.nn.Module):
def __init__(self, cfg,device):
super().__init__()
self.cfg = cfg
self.device=device
self.extractor = VitExtractor(model_name=cfg['dino_model_name'], device=device)
imagenet_norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
global_resize_transform = Resize(cfg['dino_global_patch_size'], max_size=480)
self.global_transform = transforms.Compose([global_resize_transform,
imagenet_norm
])
self.lambdas = dict(
lambda_global_cls=cfg['lambda_global_cls'],
lambda_global_ssim=0,
lambda_entire_ssim=0,
lambda_entire_cls=0,
lambda_global_identity=0
)
def update_lambda_config(self, step):
if step == self.cfg['cls_warmup']:
self.lambdas['lambda_global_ssim'] = self.cfg['lambda_global_ssim']
self.lambdas['lambda_global_identity'] = self.cfg['lambda_global_identity']
if step % self.cfg['entire_A_every'] == 0:
self.lambdas['lambda_entire_ssim'] = self.cfg['lambda_entire_ssim']
self.lambdas['lambda_entire_cls'] = self.cfg['lambda_entire_cls']
else:
self.lambdas['lambda_entire_ssim'] = 0
self.lambdas['lambda_entire_cls'] = 0
def forward(self, outputs, inputs):
self.update_lambda_config(inputs['step'])
losses = {}
loss_G = 0
if self.lambdas['lambda_global_ssim'] > 0:
losses['loss_global_ssim'] = self.calculate_global_ssim_loss(outputs['x_global'], inputs['A_global'])
loss_G += losses['loss_global_ssim'] * self.lambdas['lambda_global_ssim']
if self.lambdas['lambda_entire_ssim'] > 0:
losses['loss_entire_ssim'] = self.calculate_global_ssim_loss(outputs['x_entire'], inputs['A'])
loss_G += losses['loss_entire_ssim'] * self.lambdas['lambda_entire_ssim']
if self.lambdas['lambda_entire_cls'] > 0:
losses['loss_entire_cls'] = self.calculate_crop_cls_loss(outputs['x_entire'], inputs['B_global'])
loss_G += losses['loss_entire_cls'] * self.lambdas['lambda_entire_cls']
if self.lambdas['lambda_global_cls'] > 0:
losses['loss_global_cls'] = self.calculate_crop_cls_loss(outputs['x_global'], inputs['B_global'])
loss_G += losses['loss_global_cls'] * self.lambdas['lambda_global_cls']
if self.lambdas['lambda_global_identity'] > 0:
losses['loss_global_id_B'] = self.calculate_global_id_loss(outputs['y_global'], inputs['B_global'])
loss_G += losses['loss_global_id_B'] * self.lambdas['lambda_global_identity']
losses['loss'] = loss_G
return losses
def calculate_global_ssim_loss(self, outputs, inputs):
loss = 0.0
for a, b in zip(inputs, outputs): # avoid memory limitations
a = self.global_transform(a)
b = self.global_transform(b)
with torch.no_grad():
target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
loss += F.mse_loss(keys_ssim, target_keys_self_sim)
return loss
def calculate_crop_cls_loss(self, outputs, inputs):
loss = 0.0
for a, b in zip(outputs, inputs): # avoid memory limitations
a = self.global_transform(a).unsqueeze(0).to(self.device)
b = self.global_transform(b).unsqueeze(0).to(self.device)
cls_token = self.extractor.get_feature_from_input(a)[-1][0, 0, :]
with torch.no_grad():
target_cls_token = self.extractor.get_feature_from_input(b)[-1][0, 0, :]
loss += F.mse_loss(cls_token, target_cls_token)
return loss
def calculate_global_id_loss(self, outputs, inputs):
loss = 0.0
for a, b in zip(inputs, outputs):
a = self.global_transform(a)
b = self.global_transform(b)
with torch.no_grad():
keys_a = self.extractor.get_keys_from_input(a.unsqueeze(0), 11)
keys_b = self.extractor.get_keys_from_input(b.unsqueeze(0), 11)
loss += F.mse_loss(keys_a, keys_b)
return loss
class MetricsCalculator:
def __init__(self, device) -> None:
self.device=device
self.clip_metric_calculator = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
self.psnr_metric_calculator = PeakSignalNoiseRatio(data_range=1.0).to(device)
self.lpips_metric_calculator = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to(device)
self.mse_metric_calculator = MeanSquaredError().to(device)
self.ssim_metric_calculator = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
self.structure_distance_metric_calculator = LossG(cfg={
'dino_model_name': 'dino_vitb8', # ['dino_vitb8', 'dino_vits8', 'dino_vitb16', 'dino_vits16']
'dino_global_patch_size': 224,
'lambda_global_cls': 10.0,
'lambda_global_ssim': 1.0,
'lambda_global_identity': 1.0,
'entire_A_every':75,
'lambda_entire_cls':10,
'lambda_entire_ssim':1.0
},device=device)
def calculate_clip_similarity(self, img, txt, mask=None):
img = np.array(img)
if mask is not None:
mask = np.array(mask)
img = np.uint8(img * mask)
img_tensor=torch.tensor(img).permute(2,0,1).to(self.device)
score = self.clip_metric_calculator(img_tensor, txt)
score = score.cpu().item()
return score
def calculate_psnr(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
img_pred = np.array(img_pred).astype(np.float32)/255
img_gt = np.array(img_gt).astype(np.float32)/255
assert img_pred.shape == img_gt.shape, "Image shapes should be the same."
if mask_pred is not None:
mask_pred = np.array(mask_pred).astype(np.float32)
img_pred = img_pred * mask_pred
if mask_gt is not None:
mask_gt = np.array(mask_gt).astype(np.float32)
img_gt = img_gt * mask_gt
img_pred_tensor=torch.tensor(img_pred).permute(2,0,1).unsqueeze(0).to(self.device)
img_gt_tensor=torch.tensor(img_gt).permute(2,0,1).unsqueeze(0).to(self.device)
score = self.psnr_metric_calculator(img_pred_tensor,img_gt_tensor)
score = score.cpu().item()
return score
def calculate_lpips(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
img_pred = np.array(img_pred).astype(np.float32)/255
img_gt = np.array(img_gt).astype(np.float32)/255
assert img_pred.shape == img_gt.shape, "Image shapes should be the same."
if mask_pred is not None:
mask_pred = np.array(mask_pred).astype(np.float32)
img_pred = img_pred * mask_pred
if mask_gt is not None:
mask_gt = np.array(mask_gt).astype(np.float32)
img_gt = img_gt * mask_gt
img_pred_tensor=torch.tensor(img_pred).permute(2,0,1).unsqueeze(0).to(self.device)
img_gt_tensor=torch.tensor(img_gt).permute(2,0,1).unsqueeze(0).to(self.device)
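        # LPIPS expects inputs in [-1, 1], hence the rescaling from [0, 1] below.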
score = self.lpips_metric_calculator(img_pred_tensor*2-1,img_gt_tensor*2-1)
score = score.cpu().item()
return score
def calculate_mse(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
img_pred = np.array(img_pred).astype(np.float32)/255
img_gt = np.array(img_gt).astype(np.float32)/255
assert img_pred.shape == img_gt.shape, "Image shapes should be the same."
if mask_pred is not None:
mask_pred = np.array(mask_pred).astype(np.float32)
img_pred = img_pred * mask_pred
if mask_gt is not None:
mask_gt = np.array(mask_gt).astype(np.float32)
img_gt = img_gt * mask_gt
img_pred_tensor=torch.tensor(img_pred).permute(2,0,1).to(self.device)
img_gt_tensor=torch.tensor(img_gt).permute(2,0,1).to(self.device)
score = self.mse_metric_calculator(img_pred_tensor.contiguous(),img_gt_tensor.contiguous())
score = score.cpu().item()
return score
def calculate_ssim(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
img_pred = np.array(img_pred).astype(np.float32)/255
img_gt = np.array(img_gt).astype(np.float32)/255
assert img_pred.shape == img_gt.shape, "Image shapes should be the same."
if mask_pred is not None:
mask_pred = np.array(mask_pred).astype(np.float32)
img_pred = img_pred * mask_pred
if mask_gt is not None:
mask_gt = np.array(mask_gt).astype(np.float32)
img_gt = img_gt * mask_gt
img_pred_tensor=torch.tensor(img_pred).permute(2,0,1).unsqueeze(0).to(self.device)
img_gt_tensor=torch.tensor(img_gt).permute(2,0,1).unsqueeze(0).to(self.device)
score = self.ssim_metric_calculator(img_pred_tensor,img_gt_tensor)
score = score.cpu().item()
return score
def calculate_structure_distance(self, img_pred, img_gt, mask_pred=None, mask_gt=None, use_gpu = True):
img_pred = np.array(img_pred).astype(np.float32)
img_gt = np.array(img_gt).astype(np.float32)
assert img_pred.shape == img_gt.shape, "Image shapes should be the same."
if mask_pred is not None:
mask_pred = np.array(mask_pred).astype(np.float32)
img_pred = img_pred * mask_pred
if mask_gt is not None:
mask_gt = np.array(mask_gt).astype(np.float32)
img_gt = img_gt * mask_gt
img_pred = torch.from_numpy(np.transpose(img_pred, axes=(2, 0, 1))).to(self.device)
img_gt = torch.from_numpy(np.transpose(img_gt, axes=(2, 0, 1))).to(self.device)
img_pred = torch.unsqueeze(img_pred, 0)
img_gt = torch.unsqueeze(img_gt, 0)
structure_distance = self.structure_distance_metric_calculator.calculate_global_ssim_loss(img_gt, img_pred)
return structure_distance.data.cpu().numpy()
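# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original evaluation
# code). The file names and the caption below are placeholders; both images are
# assumed to be same-sized RGB images.
if __name__ == "__main__":
    from PIL import Image

    calculator = MetricsCalculator(device="cuda" if torch.cuda.is_available() else "cpu")
    src_img = Image.open("source.png").convert("RGB")
    edit_img = Image.open("edited.png").convert("RGB")

    # Whole-image metrics; float masks in [0, 1] can be passed to restrict the
    # scores to a particular region.
    print("PSNR :", calculator.calculate_psnr(edit_img, src_img))
    print("LPIPS:", calculator.calculate_lpips(edit_img, src_img))
    print("SSIM :", calculator.calculate_ssim(edit_img, src_img))
    print("CLIP :", calculator.calculate_clip_similarity(edit_img, "a photo of a cat"))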
================================================
FILE: models/InstructDiffusion/.gitignore
================================================
data/
checkpoints/
stable_diffusion/models/ldm/stable-diffusion-v1/
src/
logs/
cache/
imgs/*
work_dirs/*
wandb/*
DeepSpeed/*
inference_submit_sing_alpha.yaml
inference_submit_sing.yaml
inference_submit.yaml
train_submit.yaml
.amltconfig
.amltignore
test.sh
debug.yaml
post-process/
test_single.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
inference_submit_sing_grefcoco.yaml
================================================
FILE: models/InstructDiffusion/LICENSE
================================================
Copyright 2023 Authors of InstructDiffusion(https://arxiv.org/pdf/2309.03895.pdf)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Portions of code and models (such as pretrained checkpoints, which are fine-tuned starting from released Stable Diffusion checkpoints) are derived from the Stable Diffusion codebase (https://github.com/CompVis/stable-diffusion) and Instruct-pix2pix codebase (https://github.com/timothybrooks/instruct-pix2pix). Further restrictions may apply. Please consult the Stable Diffusion license `stable_diffusion/LICENSE` and the Instruct-pix2pix license `instruct-pix2pix/LICENSE`. Modified code is denoted as such in comments at the start of each file.
================================================
FILE: models/InstructDiffusion/README.md
================================================
# InstructDiffusion: A Generalist Modeling Interface for Vision Tasks
<p align="center">
<a href="https://gengzigang.github.io/instructdiffusion.github.io/">Project Page</a> |
<a href="https://arxiv.org/pdf/2309.03895.pdf">Arxiv</a> |
<a href="https://e0448e59d09dbe092f.gradio.live">Web Demo</a> |
<a href="#QuickStart">QuickStart</a> |
<a href="#Training">Training</a> |
<a href="#Acknowledge">Acknowledge</a> |
<a href='#Citation'>Citation</a>
</p>
<div align="center">
<img src="figure/teaser.png" width="1000"/>
</div>
This is the PyTorch implementation of InstructDiffusion, a unifying and generic framework for aligning computer vision tasks with human instructions. Our code is based on [Instruct-pix2pix](https://github.com/timothybrooks/instruct-pix2pix) and [CompVis/stable_diffusion](https://github.com/CompVis/stable-diffusion).<br>
## QuickStart
Follow the steps below to quickly edit your own images. The inference code in our repository requires **one GPU with > 9GB memory** to test images at a resolution of **512**.
1. Clone this repo.
2. Setup conda environment:
```
conda env create -f environment.yaml
conda activate instructdiff
```
3. We provide a well-trained [checkpoint](https://mailustceducn-my.sharepoint.com/:u:/g/personal/aa397601_mail_ustc_edu_cn/EZmXduulFidIhJD73SGcbOoBNpm18CJmU4PgPTS21RM2Ow?e=KqQYpO) and a [checkpoint](https://mailustceducn-my.sharepoint.com/:u:/g/personal/aa397601_mail_ustc_edu_cn/EWlNmyeS9P1BkRg_IlXbPbwBeNMQXQTcIA0pCokyd61UWg?e=iKfRdk) that has undergone human alignment. Feel free to download them to the `checkpoints` folder and try both.
4. You can edit your own images:
```bash
python edit_cli.py --input example.jpg --edit "Transform it to van Gogh, starry night style."
# Optionally, you can customize the parameters by using the following syntax:
# --resolution 512 --steps 50 --config configs/instruct_diffusion.yaml --ckpt YOUR_CHECKPOINT --cfg-text 3.5 --cfg-image 1.25
# We also support loading an image from a URL and editing it, e.g., you could run a command like this:
python edit_cli.py --input "https://wallup.net/wp-content/uploads/2016/01/207131-animals-nature-lion.jpg" \
--edit "Transform it to van Gogh, starry night style." \
--resolution 512 --steps 50 \
--config configs/instruct_diffusion.yaml \
--ckpt checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt \
--outdir logs/
```
For other tasks, we provide recommended parameter settings, which can be found in [`scripts/inference_example.sh`](./scripts/inference_example.sh).
5. (Optional) You can launch your own interactive editing Gradio app:
```bash
python edit_app.py
# You can also specify the path to the checkpoint
# The default checkpoint is checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt
python edit_app.py --ckpt checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt
```
## Training
The code is developed with Python 3.8 on Ubuntu 18.04 and has been tested on 48 NVIDIA V100 GPUs, each with 32GB of memory. Other platforms are not fully tested.
### Installation
1. Clone this repo.
2. Setup conda environment:
```
conda env create -f environment.yaml
conda activate instructdiff
```
### Pre-trained Model Preparation
You can use the following command to download the official pre-trained Stable Diffusion model, or you can download the model produced by our pretraining adaptation process from [OneDrive](https://mailustceducn-my.sharepoint.com/:u:/g/personal/aa397601_mail_ustc_edu_cn/EXJSMIpFev5Nj0kuKI88U1IBZDSjegp3G8ukku0OxRRjFQ?e=QhnnB4) and put it into the folder `stable_diffusion/models/ldm/stable-diffusion-v1/`.
```
bash scripts/download_pretrained_sd.sh
```
### Data Preparation
You can refer to the [dataset](https://github.com/Gengzigang/InstructDiffusion/tree/main/dataset) to prepare your data.
### Training Command
For multi-GPU training on a single machine, you can use the following command:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --name v0 --base configs/instruct_diffusion.yaml --train --logdir logs/instruct_diffusion
```
For multi-GPU training on multiple machines, you can use the following command (assuming 6 machines as an example):
```
bash run_multinode.sh instruct_diffusion v0 6
```
### Convert EMA-Model
You can get the final EMA checkpoint for inference using the command below:
```
python convert_ckpt.py --ema-ckpt logs/instruct_diffusion/checkpoint/ckpt_epoch_200/state.pth --out-ckpt checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt
```
## Acknowledge
Thanks to
- [Stable-diffusion](https://github.com/CompVis/stable-diffusion)
- [Instruct-pix2pix](https://github.com/timothybrooks/instruct-pix2pix)
## Citation
```
@inproceedings{Geng23instructdiff,
author={Zigang Geng and Binxin Yang and Tiankai Hang and Chen Li and Shuyang Gu and Ting Zhang and Jianmin Bao and Zheng Zhang and Han Hu and Dong Chen and Baining Guo},
title={InstructDiffusion: A Generalist Modeling Interface for Vision Tasks},
booktitle={{Arxiv}},
year={2023},
}
```
================================================
FILE: models/InstructDiffusion/configs/instruct_diffusion.yaml
================================================
# File modified by authors of InstructDiffusion from original (https://github.com/CompVis/stable-diffusion).
# See more details in LICENSE.
model:
base_learning_rate: 1.0e-04
weight_decay: 0.01
target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
params:
fp16: True
deepspeed: 'deepspeed_1'
ckpt_path: data/checkpoints/v1-5-pruned-emaonly-adaption.ckpt
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: edited
cond_stage_key: edit
image_size: 32
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 0 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
force_type_convert: True
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 64
num_workers: 4
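    # Each training entry below maps `target` to a Dataset class under dataset/;
    # `sample_weight` scales that dataset's effective length (see __len__ in the
    # Dataset classes), i.e. how often it is sampled relative to the others.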
train:
- ds1:
target: dataset.pose.pose.MPIIDataset
params:
root: data/mpii/
image_set: train
is_train: True
max_prompt_num: 5
min_prompt_num: 1
radius: 10
- ds2:
target: dataset.pose.pose.COCODataset
params:
root: data/coco/
image_set: train2017
is_train: True
max_prompt_num: 5
min_prompt_num: 1
radius: 10
- ds3:
target: dataset.pose.pose.CrowdPoseDataset
params:
root: data/crowdpose/
image_set: train
is_train: True
max_prompt_num: 5
min_prompt_num: 1
radius: 10
- ds4:
target: dataset.pose.pose.AICDataset
params:
root: data/aic/
image_set: train
is_train: True
max_prompt_num: 5
min_prompt_num: 1
radius: 10
sample_weight: 0.1
- ds5:
target: dataset.seg.coco_stuff.COCOStuffDataset
params:
path: data/coco-stuff
split: train2017
crop_res: 256
flip_prob: 0.5
transparency: 0.5
empty_percentage: 0.2
- ds6:
target: dataset.seg.grefcoco_segmentation.GrefCOCODataset
params:
path: data/coco_2014
split: train
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.0
transparency: 0.5
- ds7:
target: dataset.seg.refcoco_segmentation.RefCOCODataset
params:
path: data/coco_2014
split: train
crop_res: 256
flip_prob: 0.0
transparency: 0.5
- ds8:
target: dataset.low_level.lowlevel_gopro.GoPro
params:
path: data/GoPro
split: train
size: 256
flip_prob: 0.5
interpolation: pil_lanczos
sample_weight: 2.0
- ds9:
target: dataset.low_level.lowlevel_reds.REDS
params:
path: data/REDS
split: train
size: 256
flip_prob: 0.5
interpolation: pil_lanczos
sample_weight: 0.2
- ds10:
target: dataset.low_level.lowlevel_sidd.SIDD
params:
path: data/SIDD
split: train
size: 256
flip_prob: 0.5
interpolation: pil_lanczos
sample_weight: 20
- ds11:
target: dataset.low_level.lowlevel_clwd.CLWD
params:
path: data/CLWD
split: train
size: 256
flip_prob: 0.5
interpolation: pil_lanczos
sample_weight: 0.2
- ds12:
target: dataset.editing.edit_zip_dataset.FilteredIP2PDataset
params:
path: data/clip-filtered-dataset
split: train
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.5
sample_weight: 0.2
- ds13:
target: dataset.editing.edit_zip_dataset.GIERDataset
params:
path: data/GIER_editing_data/
split: train
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.0
zip_start_index: 0
zip_end_index: 100
sample_weight: 2.0
- ds14:
target: dataset.editing.edit_zip_dataset.GQAInpaintDataset
params:
path: data/gqa-inpaint
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.0
- ds15:
target: dataset.editing.edit_zip_dataset.MagicBrushDataset
params:
path: data/MagicBrush/
split: train
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.5
zip_start_index: 0
zip_end_index: 100
- ds16:
target: dataset.editing.edit_zip_dataset.IEIWDataset
params:
path: data/ieiw/
split: train
min_resize_res: 256
max_resize_res: 256
crop_res: 256
flip_prob: 0.5
validation:
target: dataset.pose.pose.COCODataset
params:
root: data/coco/
image_set: val2017
is_train: False
max_prompt_num: 5
min_prompt_num: 1
radius: 10
trainer:
initial_scale: 13
max_epochs: 200
save_freq: 5
accumulate_grad_batches: 1
clip_grad: 0.0
optimizer: adamw
================================================
FILE: models/InstructDiffusion/dataset/README.md
================================================
You can download these datasets: [COCO](http://cocodataset.org/#download), [CrowdPose](https://github.com/Jeff-sjtu/CrowdPose#dataset), [MPII](http://human-pose.mpi-inf.mpg.de/), [AIC](https://arxiv.org/abs/1711.06475), [COCO-Stuff](https://github.com/nightrome/cocostuff), [RefCOCO](https://github.com/lichengunc/refer), [GrefCOCO](https://github.com/henghuiding/gRefCOCO), [GoPro](https://seungjunnah.github.io/Datasets/gopro), [REDS](https://seungjunnah.github.io/Datasets/reds.html), [SIDD](https://www.eecs.yorku.ca/~kamel/sidd/), [CLWD](https://arxiv.org/abs/2012.07616), [IP2PDataset](https://github.com/timothybrooks/instruct-pix2pix), [GIER](https://sites.google.com/view/gierdataset), [GQAInpaint](https://github.com/abyildirim/inst-inpaint), [MagicBrush](https://osu-nlp-group.github.io/MagicBrush/). The resulting data directory should look like this:
InstructDiffusion
|-- data
`-- |-- coco
| |-- annotations
| `-- images
|-- mpii
| |-- annot
| `-- images
|-- crowdpose
| |-- json
| `-- images
|-- aic
| |-- annotations
| `-- ai_challenger_keypoint_train_20170902
|
|-- coco-stuff
| |-- annotations
| |-- labels.txt
| `-- images
|-- coco_2014
| |-- grefcoco
| | |-- grefs(unc).json
| | `-- instances.json
| |-- refcoco
| | |-- instances.json
| | |-- refs(google).p
| | `-- refs(unc).p
| `-- images
|
|-- GoPro
| |-- train
| `-- test
|-- REDS
| |-- train
| `-- val
|-- SIDD
| |-- train
| `-- val
|-- CLWD
| |-- train
| |-- test
| `-- watermark_logo
|
|-- clip-filtered-dataset
| |-- shard-00.zip
| |-- shard-01.zip
| `-- ...
|-- GIER_editing_data
| |-- images
| `-- GIER.json
|-- gqa-inpaint
| |-- images
| |-- images_inpainted
| |-- masks
| |-- train_scenes.json
| `-- meta_info.json
`-- MagicBrush
|-- data
|-- processed-train
`-- magic_train.json
================================================
FILE: models/InstructDiffusion/dataset/editing/edit_zip_dataset.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Tiankai Hang (tkhang@seu.edu.cn)
# --------------------------------------------------------
from __future__ import annotations
import os
import json
import math
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torchvision
from einops import rearrange
import PIL
from PIL import Image
from torch.utils.data import Dataset
import random
from dataset.utils.zip_manager import MultipleZipManager
if hasattr(Image, "Resampling"):
# deprecated in pillow >= 10.0.0
RESAMPLING_METHOD = Image.Resampling.LANCZOS
else:
RESAMPLING_METHOD = Image.LANCZOS
class FilteredIP2PDataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
zip_start_index: int = 0,
zip_end_index: int = 30,
instruct: bool = False,
max_num_images = None,
sample_weight: float = 1.0,
reverse_version: bool = False,
**kwargs
):
assert split in ("train", "val", "test")
assert sum(splits) == 1
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.instruct = instruct
zip_list = []
for i in range(zip_start_index, zip_end_index):
name = "shard-"+str(i).zfill(2)+'.zip'
zip_list.append(os.path.join(self.path, name))
self.image_dataset = MultipleZipManager(zip_list, 'image', sync=True) # sync=True is faster
with open(Path(self.path, "seeds.json")) as f:
self.seeds = json.load(f)
split_0, split_1 = {
"train": (0.0, splits[0]),
"val": (splits[0], splits[0] + splits[1]),
"test": (splits[0] + splits[1], 1.0),
}[split]
idx_0 = math.floor(split_0 * len(self.seeds))
idx_1 = math.floor(split_1 * len(self.seeds))
self.seeds = self.seeds[idx_0:idx_1]
if max_num_images is not None and max_num_images > 0:
self.seeds = self.seeds[:min(max_num_images, len(self.seeds))]
# flatten seeds
self.seeds = [(name, seed) for name, seeds in self.seeds for seed in seeds]
self.sample_weight = sample_weight
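        # The ids of the filtered samples are read from a local json file; if it
        # is missing, it is downloaded first and the read is retried.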
while True:
try:
with open('filtered_ids_ip2p.json') as json_file:
filtered_ids = json.load(json_file)
break
except:
# download json file from url
if reverse_version:
os.system('wget https://github.com/TiankaiHang/storage/releases/download/readout/filtered_ids_ip2p.json')
else:
os.system("wget https://github.com/TiankaiHang/storage/releases/download/readout/filtered-ip2p-thres5.5-0.5.json -O filtered_ids_ip2p.json")
print("seeds:", len(self.seeds))
# self.seeds = [seed for seed in self.seeds if seed[1] in filtered_ids]
# faster
# self.seeds = list(filter(lambda seed: seed[1] in filtered_ids, self.seeds))
# to numpy and faster in parallel
# import pdb; pdb.set_trace()
_seeds = [f"{a}/{b}" for a, b in self.seeds]
self.seeds = np.array(self.seeds)
_seeds = np.array(_seeds)
self.seeds = self.seeds[np.isin(_seeds, filtered_ids)]
self.seeds = self.seeds.tolist()
self.return_add_kwargs = kwargs.get("return_add_kwargs", False)
def __len__(self) -> int:
return int(len(self.seeds) * self.sample_weight)
def __getitem__(self, i: int) -> dict[str, Any]:
# name, seeds = self.seeds[i]
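        # sample_weight rescales the dataset length reported by __len__: for
        # weights >= 1 indices simply wrap around, while for weights < 1 a
        # random element is drawn from the corresponding stride of the data.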
if self.sample_weight >= 1:
i = i % len(self.seeds)
else:
remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
name, seed = self.seeds[i]
propt_name = name + "/prompt.json"
if not self.image_dataset.managers[self.image_dataset.mapping[propt_name]]._init:
self.image_dataset.managers[self.image_dataset.mapping[propt_name]].initialize(close=False)
# propt_name = name + "/prompt.json"
byteflow = self.image_dataset.managers[self.image_dataset.mapping[propt_name]].zip_fd.read(propt_name)
texts = json.loads(byteflow.decode('utf-8'))
prompt = texts["edit"]
if self.instruct:
prompt = "Image Editing: " + prompt
text_input = texts["input"]
text_output = texts["output"]
# image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
# image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
image_0 = self.image_dataset.get(name+f"/{seed}_0.jpg")
image_1 = self.image_dataset.get(name+f"/{seed}_1.jpg")
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
image_0 = image_0.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_1 = image_1.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
if self.return_add_kwargs:
add_kwargs = dict(
name=name,
seed=seed,
text_input=text_input,
text_output=text_output,
)
else:
add_kwargs = {}
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt), **add_kwargs)
class GIERDataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
zip_start_index: int = 0,
zip_end_index: int = 30,
sample_weight: float = 1.0,
instruct: bool = False,
):
assert split in ("train", "val", "test")
assert sum(splits) == 1
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.instruct = instruct
# self.meta = torch.load(Path(self.path, "GIER.json"), map_location="cpu")
# load json file
with open(Path(self.path, "GIER_new.json")) as json_file:
self.meta = json.load(json_file)
print(f"||||||||||||||||||||||||||||| \n Loaded {len(self.meta)} images from json file")
input_does_not_exist = []
output_does_not_exist = []
        # filter out images that do not exist
if not os.path.exists(os.path.join(self.path, "filtered_meta_new.pt")):
filtered_meta = []
for i in range(len(self.meta)):
input_path = os.path.join(self.path, "warped", self.meta[i]["input"])
output_path = os.path.join(self.path, "warped", self.meta[i]["output"])
if not os.path.exists(input_path):
input_path = os.path.join(self.path, "images", self.meta[i]["input"])
if not os.path.exists(input_path):
input_does_not_exist.append(input_path)
if not os.path.exists(output_path):
output_path = os.path.join(self.path, "images", self.meta[i]["output"])
if not os.path.exists(output_path):
output_does_not_exist.append(output_path)
if os.path.exists(input_path) and os.path.exists(output_path):
filtered_meta.append(
dict(
input=input_path,
output=output_path,
prompts=self.meta[i]["prompts"],
)
)
else:
print(f"\n {input_path} or {output_path} does not exist")
torch.save(filtered_meta, os.path.join(self.path, "filtered_meta_new.pt"))
else:
filtered_meta = torch.load(os.path.join(self.path, "filtered_meta_new.pt"), map_location="cpu")
self.meta = filtered_meta
print(f"||||||||||||||||||||||||||||| \n Filtered {len(self.meta)} images")
for i in range(len(self.meta)):
self.meta[i]['input'] = self.meta[i]['input'].replace('/mnt/external/datasets/GIER_editing_data/', self.path)
self.meta[i]['output'] = self.meta[i]['output'].replace('/mnt/external/datasets/GIER_editing_data/', self.path)
# write input_does_not_exist and output_does_not_exist to file
with open(Path(self.path, f"input_does_not_exist.txt"), "w") as f:
for item in input_does_not_exist:
f.write("%s\n" % item)
with open(Path(self.path, f"output_does_not_exist.txt"), "w") as f:
for item in output_does_not_exist:
f.write("%s\n" % item)
split_0, split_1 = {
"train": (0.0, splits[0]),
"val": (splits[0], splits[0] + splits[1]),
"test": (splits[0] + splits[1], 1.0),
}[split]
idx_0 = math.floor(split_0 * len(self.meta))
idx_1 = math.floor(split_1 * len(self.meta))
self.meta = self.meta[idx_0:idx_1]
self.sample_weight = sample_weight
print('original GIER', len(self.meta))
def __len__(self) -> int:
return int(len(self.meta) * self.sample_weight)
def __getitem__(self, i: int) -> dict[str, Any]:
if self.sample_weight >= 1:
i = i % len(self.meta)
else:
i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
# prompt = self.meta[i]["prompts"]
prompt = random.choice(self.meta[i]["prompts"])
try:
image_0 = Image.open(self.meta[i]["input"]).convert("RGB")
image_1 = Image.open(self.meta[i]["output"]).convert("RGB")
except PIL.UnidentifiedImageError:
print(f"\n {self.meta[i]['input']} or {self.meta[i]['output']} is not a valid image")
i = random.randint(0, len(self.meta) - 1)
return self.__getitem__(i)
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
image_0 = image_0.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_1 = image_1.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
if self.instruct:
prompt = "Image Editing: " + prompt
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
class GQAInpaintDataset(Dataset):
r"""
    should download and unzip the data first
```
mkdir -p ../datasets
cd ../datasets
# if file exists, then skip
if [ ! -f "gqa-inpaint.zip" ]; then
sudo azcopy copy "https://bingdatawu2.blob.core.windows.net/genrecog/private/t-thang/gqa-inpaint.zip${TOKEN}" .
unzip gqa-inpaint.zip -d gqa-inpaint > /dev/null
fi
if [ ! -f "images.zip" ]; then
sudo azcopy copy "https://bingdatawu2.blob.core.windows.net/genrecog/private/t-thang/images.zip${TOKEN}" .
unzip images.zip > /dev/null
fi
```
"""
def __init__(self, **kwargs):
# load from json ../datasets/gqa-inpaint/meta_info.json
self.path = kwargs.get("path", "../datasets/gqa-inpaint")
self.instruct = kwargs.get("instruct", False)
with open(self.path + "/meta_info.json", "r") as f:
self.meta_info = json.load(f)
self.min_resize_res = kwargs.get("min_resize_res", 256)
self.max_resize_res = kwargs.get("max_resize_res", 256)
self.crop_res = kwargs.get("crop_res", 256)
self.flip_prob = kwargs.get("flip_prob", 0.5)
def __len__(self):
return len(self.meta_info)
def __getitem__(self, i):
item = self.meta_info[i]
src_img = Image.open(item["source_image_path"].replace("../datasets", self.path)).convert("RGB")
tgt_img = Image.open(item["target_image_path"].replace("../datasets/gqa-inpaint", self.path)).convert("RGB")
image_0 = src_img
image_1 = tgt_img
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
image_0 = image_0.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_1 = image_1.resize((reize_res, reize_res), RESAMPLING_METHOD)
instruction = item["instruction"]
if self.instruct:
instruction = "Image Editing: " + instruction
# return image_0, image_1, instruction
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=instruction))
class MagicBrushDataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
zip_start_index: int = 0,
zip_end_index: int = 30,
len_dataset: int = -1,
instruct: bool = False,
sample_weight: float = 1.0,
):
assert split in ("train", "val", "test")
assert sum(splits) == 1
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.instruct = instruct
self.sample_weight = sample_weight
self.meta_path = os.path.join(self.path, "magic_train.json")
with open(self.meta_path, "r") as f:
self.meta = json.load(f)
def __len__(self) -> int:
return int(len(self.meta) * self.sample_weight)
def __getitem__(self, i: int) -> dict[str, Any]:
if self.sample_weight >= 1:
i = i % len(self.meta)
else:
i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
item = self.meta[i]
try:
image_0 = Image.open(os.path.join(self.path, item["input"])).convert("RGB")
image_1 = Image.open(os.path.join(self.path, item["edited"])).convert("RGB")
except (PIL.UnidentifiedImageError, FileNotFoundError):
print(f"\n {self.path}/{item['input']} or {self.path}/{item['edited']} is not a valid image")
i = random.randint(0, len(self.meta) - 1)
return self.__getitem__(i)
prompt = item["instruction"]
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
image_0 = image_0.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_1 = image_1.resize((reize_res, reize_res), RESAMPLING_METHOD)
if self.instruct:
prompt = "Image Editing: " + prompt
# return image_0, image_1, prompt
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
class IEIWDataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
zip_start_index: int = 0,
zip_end_index: int = 30,
sample_weight: float = 1.0,
instruct: bool = False,
):
assert split in ("train", "val", "test")
assert sum(splits) == 1
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.instruct = instruct
self.meta_path = os.path.join(self.path, "meta_infov1.json")
with open(self.meta_path, "r") as f:
self.meta = json.load(f)
self.sample_weight = sample_weight
print('original synthetic', len(self.meta))
def __len__(self) -> int:
return int(len(self.meta) * self.sample_weight)
def __getitem__(self, i: int) -> dict[str, Any]:
if self.sample_weight >= 1:
i = i % len(self.meta)
else:
i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
item = self.meta[i]
item['input'] = item['input'].replace('/mnt/external/tmp/2023/06/11/', self.path)
item['edited'] = item['edited'].replace('/mnt/external/tmp/2023/06/11/', self.path)
try:
image_0 = Image.open(item["input"]).convert("RGB")
image_1 = Image.open(item["edited"]).convert("RGB")
except (PIL.UnidentifiedImageError, FileNotFoundError):
print(f"\n {item['input']} or {item['edited']} is not a valid image")
i = random.randint(0, len(self.meta) - 1)
return self.__getitem__(i)
prompt = item["instruction"]
reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
image_0 = image_0.resize((reize_res, reize_res), RESAMPLING_METHOD)
image_1 = image_1.resize((reize_res, reize_res), RESAMPLING_METHOD)
if self.instruct:
prompt = "Image Editing: " + prompt
# return image_0, image_1, prompt
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
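# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how a `train` entry in
# configs/instruct_diffusion.yaml maps onto one of the Dataset classes above.
# The path is a placeholder and assumes the MagicBrush layout described in
# dataset/README.md.
if __name__ == "__main__":
    dataset = MagicBrushDataset(
        path="data/MagicBrush/",
        split="train",
        min_resize_res=256,
        max_resize_res=256,
        crop_res=256,
        flip_prob=0.5,
    )
    sample = dataset[0]
    print(sample["edit"]["c_crossattn"])   # editing instruction
    print(sample["edited"].shape)          # torch.Size([3, 256, 256])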
================================================
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_clwd.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Chen Li (edward82@stu.xjtu.edu.cn)
# --------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
import cv2
import torchvision
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class CLWD(Dataset):
def __init__(self, path, split="train", size=256, interpolation="pil_lanczos",
flip_prob=0.5, sample_weight=1.0, instruct=False):
super(CLWD, self).__init__()
inp_files = sorted(os.listdir(os.path.join(path, split, 'Watermarked_image')))
tar_files = sorted(os.listdir(os.path.join(path, split, 'Watermark_free_image')))
self.inp_filenames = [os.path.join(path, split, 'Watermarked_image', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(path, split, 'Watermark_free_image', x) for x in tar_files if is_image_file(x)]
self.size = size
self.flip_prob = flip_prob
self.sample_weight = sample_weight
self.instruct = instruct
self.sizex = len(self.tar_filenames) # get the size of target
self.interpolation = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": Image.NEAREST,
"pil_bilinear": Image.BILINEAR,
"pil_bicubic": Image.BICUBIC,
"pil_box": Image.BOX,
"pil_hamming": Image.HAMMING,
"pil_lanczos": Image.LANCZOS,
}[interpolation]
prompt_path='dataset/prompt/prompt_dewatermark.txt'
self.prompt_list=[]
with open(prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.prompt_list.append(line)
line=f.readline()
print(f"CLWD has {len(self)} samples!!")
def __len__(self):
return int(self.sizex * self.sample_weight)
def __getitem__(self, index):
if self.sample_weight >= 1:
index_ = index % self.sizex
else:
index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
width, height = inp_img.size
tar_width, tar_height = tar_img.size
assert tar_width == width and tar_height == height, "Input and target image mismatch"
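        # Resize so that the shorter side equals self.size while preserving the
        # aspect ratio; the RandomCrop below then cuts a size x size patch.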
aspect_ratio = float(width) / float(height)
if width < height:
new_width = self.size
new_height = int(self.size / aspect_ratio)
else:
new_height = self.size
new_width = int(self.size * aspect_ratio)
inp_img = inp_img.resize((new_width, new_height), self.interpolation)
tar_img = tar_img.resize((new_width, new_height), self.interpolation)
inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
crop = torchvision.transforms.RandomCrop(self.size)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
prompt = random.choice(self.prompt_list)
if self.instruct:
prompt = "Watermark Removal: " + prompt
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_gopro.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Chen Li (edward82@stu.xjtu.edu.cn)
# --------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
import cv2
import torchvision
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class GoPro(Dataset):
def __init__(self, path, split="train", size=256, interpolation="pil_lanczos",
flip_prob=0.5, sample_weight=1.0, instruct=False):
super(GoPro, self).__init__()
inp_files = sorted(os.listdir(os.path.join(path, split, 'input')))
tar_files = sorted(os.listdir(os.path.join(path, split, 'target')))
self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(path, split, 'target', x) for x in tar_files if is_image_file(x)]
self.size = size
self.flip_prob = flip_prob
self.sample_weight = sample_weight
self.instruct = instruct
self.sizex = len(self.tar_filenames) # get the size of target
self.interpolation = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": Image.NEAREST,
"pil_bilinear": Image.BILINEAR,
"pil_bicubic": Image.BICUBIC,
"pil_box": Image.BOX,
"pil_hamming": Image.HAMMING,
"pil_lanczos": Image.LANCZOS,
}[interpolation]
prompt_path='dataset/prompt/prompt_deblur.txt'
self.prompt_list=[]
with open(prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.prompt_list.append(line)
line=f.readline()
print(f"GoPro has {len(self)} samples!!")
def __len__(self):
return int(self.sizex * self.sample_weight)
def __getitem__(self, index):
if self.sample_weight >= 1:
index_ = index % self.sizex
else:
index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
width, height = inp_img.size
tar_width, tar_height = tar_img.size
assert tar_width == width and tar_height == height, "Input and target image mismatch"
aspect_ratio = float(width) / float(height)
if width < height:
new_width = self.size
new_height = int(self.size / aspect_ratio)
else:
new_height = self.size
new_width = int(self.size * aspect_ratio)
inp_img = inp_img.resize((new_width, new_height), self.interpolation)
tar_img = tar_img.resize((new_width, new_height), self.interpolation)
inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
crop = torchvision.transforms.RandomCrop(self.size)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
prompt = random.choice(self.prompt_list)
if self.instruct:
prompt = "Image Deblurring: " + prompt
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_reds.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Chen Li (edward82@stu.xjtu.edu.cn)
# --------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
import cv2
import torchvision
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class REDS(Dataset):
def __init__(self, path, split="train", size=256, interpolation="pil_lanczos",
flip_prob=0.5, sample_weight=1.0, instruct=False):
super(REDS, self).__init__()
inp_files = sorted(os.listdir(os.path.join(path, split, 'blur')))
tar_files = sorted(os.listdir(os.path.join(path, split, 'sharp')))
if split == "train":
self.inp_filenames = [os.path.join(path, split, 'blur', d, x) for d in inp_files for x in sorted(os.listdir(os.path.join(path, split, 'blur', d))) if is_image_file(x)]
self.tar_filenames = [os.path.join(path, split, 'sharp', d, x) for d in tar_files for x in sorted(os.listdir(os.path.join(path, split, 'sharp', d))) if is_image_file(x)]
else:
self.inp_filenames = [os.path.join(path, split, 'blur', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(path, split, 'sharp', x) for x in tar_files if is_image_file(x)]
self.size = size
self.flip_prob = flip_prob
self.sample_weight = sample_weight
self.instruct = instruct
assert len(self.inp_filenames) == len(self.tar_filenames)
self.sizex = len(self.tar_filenames) # get the size of target
self.interpolation = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": Image.NEAREST,
"pil_bilinear": Image.BILINEAR,
"pil_bicubic": Image.BICUBIC,
"pil_box": Image.BOX,
"pil_hamming": Image.HAMMING,
"pil_lanczos": Image.LANCZOS,
}[interpolation]
prompt_path='dataset/prompt/prompt_deblur.txt'
self.prompt_list=[]
with open(prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.prompt_list.append(line)
line=f.readline()
print(f"REDS has {len(self)} samples!!")
def __len__(self):
return int(self.sizex * self.sample_weight)
def __getitem__(self, index):
if self.sample_weight >= 1:
index_ = index % self.sizex
else:
index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
width, height = inp_img.size
tar_width, tar_height = tar_img.size
assert tar_width == width and tar_height == height, "Input and target image mismatch"
aspect_ratio = float(width) / float(height)
if width < height:
new_width = self.size
new_height = int(self.size / aspect_ratio)
else:
new_height = self.size
new_width = int(self.size * aspect_ratio)
inp_img = inp_img.resize((new_width, new_height), self.interpolation)
tar_img = tar_img.resize((new_width, new_height), self.interpolation)
inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
crop = torchvision.transforms.RandomCrop(self.size)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
prompt = random.choice(self.prompt_list)
if self.instruct:
prompt = "Image Deblurring: " + prompt
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_sidd.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Chen Li (edward82@stu.xjtu.edu.cn)
# --------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
import cv2
import torchvision
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class SIDD(Dataset):
def __init__(self, path, split="train", size=256, interpolation="pil_lanczos",
flip_prob=0.5, sample_weight=1.0, instruct=False):
super(SIDD, self).__init__()
inp_files = sorted(os.listdir(os.path.join(path, split, 'input')))
tar_files = sorted(os.listdir(os.path.join(path, split, 'gt')))
self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(path, split, 'gt', x) for x in tar_files if is_image_file(x)]
self.size = size
self.flip_prob = flip_prob
self.sample_weight = sample_weight
self.instruct = instruct
self.sizex = len(self.tar_filenames) # get the size of target
self.interpolation = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": Image.NEAREST,
"pil_bilinear": Image.BILINEAR,
"pil_bicubic": Image.BICUBIC,
"pil_box": Image.BOX,
"pil_hamming": Image.HAMMING,
"pil_lanczos": Image.LANCZOS,
}[interpolation]
prompt_path='dataset/prompt/prompt_denoise.txt'
self.prompt_list=[]
with open(prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.prompt_list.append(line)
line=f.readline()
print(f"SIDD has {len(self)} samples!!")
def __len__(self):
return int(self.sizex * self.sample_weight)
def __getitem__(self, index):
if self.sample_weight >= 1:
index_ = index % self.sizex
else:
index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
width, height = inp_img.size
tar_width, tar_height = tar_img.size
assert tar_width == width and tar_height == height, "Input and target image mismatch"
inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
crop = torchvision.transforms.RandomCrop(self.size)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
prompt = random.choice(self.prompt_list)
if self.instruct:
prompt = "Image Denoising: " + prompt
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/pose/pose.py
================================================
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import annotations
import logging
import os
import json
import copy
import math
import random
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pycocotools.coco import COCO
logger = logging.getLogger(__name__)
colors = {
'red': (255, 0, 0),
'green': (0, 255, 0),
'blue': (0, 0, 255),
'yellow': (255, 255, 0),
'cyan': (0, 255, 255),
'magenta': (255, 0, 255),
'gray': (128, 128, 128),
'white': (255, 255, 255),
'black': (0, 0, 0)}
def readTXT(txt_path):
with open(txt_path, 'r') as f:
listInTXT = [line.strip() for line in f]
return listInTXT
class PoseDataset(Dataset):
def __init__(self, root, image_set, is_train, max_prompt_num=5, min_prompt_num=1,
radius=10, size=256, transparency=0.0, sample_weight=1.0, transform=None):
self.sample_weight = sample_weight
self.max_prompt_num = max_prompt_num
self.min_prompt_num = min_prompt_num
self.radius = radius
self.transparency = transparency
self.num_joints = 0
self.pixel_std = 200
self.flip_pairs = []
self.parent_ids = []
self.keypoints_type = {}
self.is_train = is_train
self.image_set = image_set
self.root = root
self.scale_factor = 0.35
self.rotation_factor = 45
self.flip = True
self.num_joints_half_body = 8
self.prob_half_body = 0.3
self.image_size = np.array((size, size))
self.heatmap_size = np.array((size, size))
self.transform = transform
self.db = []
pose_diverse_prompt_path = 'dataset/prompt/prompt_pose.txt'
self.pose_diverse_prompt_list = []
with open(pose_diverse_prompt_path) as f:
line = f.readline()
while line:
line = line.strip('\n')
self.pose_diverse_prompt_list.append(line)
line = f.readline()
def _get_db(self):
raise NotImplementedError
def evaluate(self, preds, output_dir, *args, **kwargs):
raise NotImplementedError
def half_body_transform(self, joints, joints_vis):
upper_joints = []
lower_joints = []
for joint_id in range(self.num_joints):
if joints_vis[joint_id][0] > 0:
if joint_id in self.upper_body_ids:
upper_joints.append(joints[joint_id])
else:
lower_joints.append(joints[joint_id])
if np.random.randn() < 0.5 and len(upper_joints) > 2:
selected_joints = upper_joints
else:
selected_joints = lower_joints \
if len(lower_joints) > 2 else upper_joints
if len(selected_joints) < 2:
return None, None
selected_joints = np.array(selected_joints, dtype=np.float32)
center = selected_joints.mean(axis=0)[:2]
left_top = np.amin(selected_joints, axis=0)
right_bottom = np.amax(selected_joints, axis=0)
w = right_bottom[0] - left_top[0]
h = right_bottom[1] - left_top[1]
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array(
[
w * 1.0 / self.pixel_std,
h * 1.0 / self.pixel_std
],
dtype=np.float32
)
scale = scale * 1.5
return center, scale
def __len__(self,):
return int(len(self.db) * self.sample_weight)
def __getitem__(self, idx):
if self.sample_weight >= 1:
idx = idx % len(self.db)
else:
idx = int(idx / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
db_rec = copy.deepcopy(self.db[idx])
image_file = db_rec['image']
filename = db_rec['filename'] if 'filename' in db_rec else ''
imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
        data_numpy = cv2.imread(
            image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
        )
        # cv2.imread returns None on failure, so check before converting colours
        if data_numpy is None:
            logger.error('=> fail to read {}'.format(image_file))
            raise ValueError('Fail to read {}'.format(image_file))
        data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
joints = db_rec['joints_3d']
joints_vis = db_rec['joints_3d_vis']
c = db_rec['center']
s = db_rec['scale']
score = db_rec['score'] if 'score' in db_rec else 1
r = 0
if self.is_train:
if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body
and np.random.rand() < self.prob_half_body):
c_half_body, s_half_body = self.half_body_transform(
joints, joints_vis
)
if c_half_body is not None and s_half_body is not None:
c, s = c_half_body, s_half_body
sf = self.scale_factor
rf = self.rotation_factor
s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
if random.random() <= 0.6 else 0
if self.flip and random.random() <= 0.5:
data_numpy = data_numpy[:, ::-1, :]
joints, joints_vis = fliplr_joints(
joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
c[0] = data_numpy.shape[1] - c[0] - 1
trans = get_affine_transform(c, s, r, self.image_size)
input = cv2.warpAffine(
data_numpy,
trans,
(int(self.image_size[0]), int(self.image_size[1])),
flags=cv2.INTER_LINEAR)
if self.transform:
input = self.transform(input)
for i in range(self.num_joints):
if joints_vis[i, 0] > 0.0:
joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
target, prompt = self.generate_target(input, joints, joints_vis)
# return Image.fromarray(input), Image.fromarray(target), prompt
image_0 = rearrange(2 * torch.tensor(np.array(input)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(target)).float() / 255 - 1, "h w c -> c h w")
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
def generate_target(self, input, joints, joints_vis):
'''
:param input: [height, width, 3]
:param joints: [num_joints, 3]
:param joints_vis: [num_joints, 3]
        :return: target, prompt
'''
radius = self.radius
target = copy.deepcopy(input)
joint_num = random.randint(self.min_prompt_num, self.max_prompt_num)
joint_ids = np.random.choice([i for i in range(self.num_joints)], joint_num, replace=False)
random_color_names = random.sample(list(colors.keys()), len(joint_ids))
random_marker_names = ['circle' for i in range(len(joint_ids))]
prompt = ""
for color_idx, joint_id in enumerate(joint_ids):
feat_stride = self.image_size / self.heatmap_size
mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
# Check that any part of the gaussian is in-bounds
ul = [int(mu_x - radius), int(mu_y - radius)]
br = [int(mu_x + radius + 1), int(mu_y + radius + 1)]
if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
or br[0] < 0 or br[1] < 0:
# If not, just return the image as is
joints_vis[joint_id][0] = 0
continue
marker_size = 2 * radius + 1
g = np.zeros((marker_size, marker_size))
x, y = np.indices((marker_size, marker_size))
interval = int((marker_size - marker_size / math.sqrt(2)) // 2)
mask = (x - radius) ** 2 + (y - radius) ** 2 <= radius ** 2 + 1
g[mask] = 1
# Usable gaussian range
g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
# Image range
img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
v = joints_vis[joint_id][0]
random_color_name = random_color_names[color_idx]
random_color = colors[random_color_name]
prompt += random.choice(self.pose_diverse_prompt_list).format(
color=random_color_name,
joint=self.keypoints_type[joint_id])
if v > 0.5:
target[img_y[0]:img_y[1], img_x[0]:img_x[1]][g[g_y[0]:g_y[1], g_x[0]:g_x[1]]>0] \
= self.transparency*target[img_y[0]:img_y[1], img_x[0]:img_x[1]][g[g_y[0]:g_y[1], g_x[0]:g_x[1]]>0] \
+ (1-self.transparency)*np.array(random_color)
return target, prompt
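# Note: a minimal numeric sketch (not part of the original code) of the blending
# rule used in generate_target above. Every pixel p inside a sampled circle becomes
# transparency * p + (1 - transparency) * color, so transparency = 0.0 (the default)
# paints a solid marker and values near 1.0 leave the image almost untouched:
#
#     t = 0.25
#     p = np.array([100., 100., 100.])      # original RGB pixel
#     red = np.array([255., 0., 0.])        # sampled marker color
#     blended = t * p + (1 - t) * red       # -> [216.25, 25.0, 25.0]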
class COCODataset(PoseDataset):
def __init__(self, root, image_set, is_train, max_prompt_num=5, min_prompt_num=1,
radius=10, size=256, transparency=0.0, sample_weight=1.0, transform=None):
super().__init__(root, image_set, is_train, max_prompt_num, min_prompt_num,
radius, size, transparency, sample_weight, transform)
self.keypoints_type = {
0: "nose",
1: "left eye",
2: "right eye",
3: "left ear",
4: "right ear",
5: "left shoulder",
6: "right shoulder",
7: "left elbow",
8: "right elbow",
9: "left wrist",
10: "right wrist",
11: "left hip",
12: "right hip",
13: "left knee",
14: "right knee",
15: "left ankle",
16: "right ankle"
}
self.image_width = size
self.image_height = size
self.aspect_ratio = self.image_width * 1.0 / self.image_height
self.pixel_std = 200
self.coco = COCO(self._get_ann_file_keypoint())
# deal with class names
cats = [cat['name']
for cat in self.coco.loadCats(self.coco.getCatIds())]
self.classes = ['__background__'] + cats
logger.info('=> classes: {}'.format(self.classes))
self.num_classes = len(self.classes)
self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
self._coco_ind_to_class_ind = dict(
[
(self._class_to_coco_ind[cls], self._class_to_ind[cls])
for cls in self.classes[1:]
]
)
# load image file names
self.image_set_index = self._load_image_set_index()
self.num_images = len(self.image_set_index)
logger.info('=> num_images: {}'.format(self.num_images))
self.num_joints = 17
self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]]
self.parent_ids = None
self.upper_body_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
self.lower_body_ids = (11, 12, 13, 14, 15, 16)
if 'coco' in self.root:
self.db = self._get_db()
logger.info('=> load {} samples'.format(len(self.db)))
def _get_ann_file_keypoint(self):
""" self.root / annotations / person_keypoints_train2017.json """
if 'coco' in self.root:
prefix = 'person_keypoints' \
if 'test' not in self.image_set else 'image_info'
return os.path.join(
self.root,
'annotations',
prefix + '_' + self.image_set + '.json'
)
elif 'crowdpose' in self.root:
prefix = 'crowdpose'
return os.path.join(
self.root,
'json',
prefix + '_' + self.image_set + '.json'
)
elif 'aic' in self.root:
prefix = 'aic'
return os.path.join(
self.root,
'annotations',
prefix + '_' + self.image_set + '.json'
)
else:
raise ValueError('Please write the path for this new dataset.')
def _load_image_set_index(self):
""" image id: int """
image_ids = self.coco.getImgIds()
return image_ids
def _get_db(self):
gt_db = self._load_coco_keypoint_annotations()
return gt_db
def _load_coco_keypoint_annotations(self):
""" ground truth bbox and keypoints """
gt_db = []
for index in self.image_set_index:
gt_db.extend(self._load_coco_keypoint_annotation_kernal(index))
return gt_db
def _load_coco_keypoint_annotation_kernal(self, index):
"""
coco ann: [u'segmentation', u'area', u'iscrowd', u'image_id', u'bbox', u'category_id', u'id']
iscrowd:
crowd instances are handled by marking their overlaps with all categories to -1
and later excluded in training
bbox:
[x1, y1, w, h]
:param index: coco image id
:return: db entry
"""
im_ann = self.coco.loadImgs(index)[0]
width = im_ann['width']
height = im_ann['height']
annIds = self.coco.getAnnIds(imgIds=index, iscrowd=False)
objs = self.coco.loadAnns(annIds)
# sanitize bboxes
valid_objs = []
for obj in objs:
x, y, w, h = obj['bbox']
x1 = np.max((0, x))
y1 = np.max((0, y))
x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
if 'crowdpose' in self.root:
obj['area'] = 1
if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
obj['clean_bbox'] = [x1, y1, x2-x1, y2-y1]
valid_objs.append(obj)
objs = valid_objs
rec = []
for obj in objs:
cls = self._coco_ind_to_class_ind[obj['category_id']]
if cls != 1:
continue
# ignore objs without keypoints annotation
if max(obj['keypoints']) == 0:
continue
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float32)
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float32)
for ipt in range(self.num_joints):
joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
joints_3d[ipt, 2] = 0
t_vis = obj['keypoints'][ipt * 3 + 2]
if t_vis > 1:
t_vis = 1
joints_3d_vis[ipt, 0] = t_vis
joints_3d_vis[ipt, 1] = t_vis
joints_3d_vis[ipt, 2] = 0
center, scale = self._box2cs(obj['clean_bbox'][:4])
rec.append({
'image': self.image_path_from_index(index, im_ann),
'center': center,
'scale': scale,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
'filename': '',
'imgnum': 0,
})
return rec
def _box2cs(self, box):
x, y, w, h = box[:4]
return self._xywh2cs(x, y, w, h)
def _xywh2cs(self, x, y, w, h):
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array(
[w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return center, scale
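    # Note (illustrative, not part of the original code): with aspect_ratio = 1.0
    # and pixel_std = 200, the bbox [10, 20, 100, 200] maps to center (60, 120);
    # the width is padded to 200 to match the aspect ratio, so the returned scale
    # is [200 / 200, 200 / 200] * 1.25 = [1.25, 1.25].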
def image_path_from_index(self, index, im_ann):
""" example: images / train2017 / 000000119993.jpg """
if 'coco' in self.root:
file_name = '%012d.jpg' % index
if '2014' in self.image_set:
file_name = 'COCO_%s_' % self.image_set + file_name
prefix = 'test2017' if 'test' in self.image_set else self.image_set
data_name = prefix
image_path = os.path.join(
self.root, 'images', data_name, file_name)
return image_path
elif 'crowdpose' in self.root:
file_name = f'{index}.jpg'
image_path = os.path.join(
self.root, 'images', file_name)
return image_path
elif 'aic' in self.root:
file_name = im_ann["file_name"]
image_path = os.path.join(
self.root, 'ai_challenger_keypoint_train_20170902', 'keypoint_train_images_20170902', file_name)
return image_path
def flip_back(output_flipped, matched_parts):
'''
    output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
'''
assert output_flipped.ndim == 4,\
'output_flipped should be [batch_size, num_joints, height, width]'
output_flipped = output_flipped[:, :, :, ::-1]
for pair in matched_parts:
tmp = output_flipped[:, pair[0], :, :].copy()
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
output_flipped[:, pair[1], :, :] = tmp
return output_flipped
def fliplr_joints(joints, joints_vis, width, matched_parts):
"""
flip coords
"""
# Flip horizontal
joints[:, 0] = width - joints[:, 0] - 1
# Change left-right parts
for pair in matched_parts:
joints[pair[0], :], joints[pair[1], :] = \
joints[pair[1], :], joints[pair[0], :].copy()
joints_vis[pair[0], :], joints_vis[pair[1], :] = \
joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
return joints*joints_vis, joints_vis
def get_affine_transform(
center, scale, rot, output_size,
shift=np.array([0, 0], dtype=np.float32), inv=0
):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
print(scale)
scale = np.array([scale, scale])
scale_tmp = scale * 200.0
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)
src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]
def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs
return src_result
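# Note: a minimal sketch (not part of the original code) of how the affine helpers
# above are combined, mirroring __getitem__: build the 2x3 matrix once, warp the
# image, and push every visible joint through the same transform.
#
#     trans = get_affine_transform(c, s, r, image_size)
#     patch = cv2.warpAffine(img, trans,
#                            (int(image_size[0]), int(image_size[1])),
#                            flags=cv2.INTER_LINEAR)
#     joint_xy = affine_transform(joints[i, 0:2], trans)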
class CrowdPoseDataset(COCODataset):
def __init__(self, root, image_set, is_train, max_prompt_num=5, min_prompt_num=1,
radius=10, size=256, transparency=0.0, sample_weight=1.0, transform=None):
super().__init__(root, image_set, is_train, max_prompt_num, min_prompt_num,
radius, size, transparency, sample_weight, transform)
self.keypoints_type = {
0: 'left_shoulder',
1: 'right_shoulder',
2: 'left_elbow',
3: 'right_elbow',
4: 'left_wrist',
5: 'right_wrist',
6: 'left_hip',
7: 'right_hip',
8: 'left_knee',
9: 'right_knee',
10: 'left_ankle',
11: 'right_ankle',
12: 'top_head',
13: 'neck'
}
self.num_joints = 14
self.prob_half_body = -1
self.flip_pairs = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]
self.parent_ids = None
self.upper_body_ids = (0, 1, 2, 3, 4, 5, 12, 13)
self.lower_body_ids = (6, 7, 8, 9, 10, 11)
self.db = self._get_db()
logger.info('=> load {} samples'.format(len(self.db)))
class AICDataset(COCODataset):
def __init__(self, root, image_set, is_train, max_prompt_num=5, min_prompt_num=1,
radius=10, size=256, transparency=0.0, sample_weight=1.0, transform=None):
super().__init__(root, image_set, is_train, max_prompt_num, min_prompt_num,
radius, size, transparency, sample_weight, transform)
self.keypoints_type = {
0: "right_shoulder",
1: "right_elbow",
2: "right_wrist",
3: "left_shoulder",
4: "left_elbow",
5: "left_wrist",
6: "right_hip",
7: "right_knee",
8: "right_ankle",
9: "left_hip",
10: "left_knee",
11: "left_ankle",
12: "head_top",
13: "neck"
}
self.num_joints = 14
self.prob_half_body = -1
self.flip_pairs = [[0, 3], [1, 4], [2, 5], [6, 9], [7, 10], [8, 11]]
self.parent_ids = None
self.upper_body_ids = (0, 1, 2, 3, 4, 5, 12, 13)
self.lower_body_ids = (6, 7, 8, 9, 10, 11)
self.db = self._get_db()
logger.info('=> load {} samples'.format(len(self.db)))
class MPIIDataset(PoseDataset):
def __init__(self, root, image_set, is_train, max_prompt_num=5, min_prompt_num=1,
radius=10, size=256, transparency=0.0, sample_weight=1.0, transform=None):
super().__init__(root, image_set, is_train, max_prompt_num, min_prompt_num,
radius, size, transparency, sample_weight, transform)
self.keypoints_type = {
0: 'right_ankle',
1: 'right_knee',
2: 'right_hip',
3: 'left_hip',
4: 'left_knee',
5: 'left_ankle',
6: 'pelvis',
7: 'thorax',
8: 'upper_neck',
9: 'head_top',
10: 'right_wrist',
11: 'right_elbow',
12: 'right_shoulder',
13: 'left_shoulder',
14: 'left_elbow',
15: 'left_wrist'
}
self.data_format = 'jpg'
self.num_joints = 16
self.prob_half_body = -1
self.flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
self.parent_ids = None
self.upper_body_ids = (7, 8, 9, 10, 11, 12, 13, 14, 15)
self.lower_body_ids = (0, 1, 2, 3, 4, 5, 6)
self.db = self._get_db()
logger.info('=> load {} samples'.format(len(self.db)))
def _get_db(self):
# create train/val split
file_name = os.path.join(
self.root, 'annot', self.image_set+'.json'
)
with open(file_name) as anno_file:
anno = json.load(anno_file)
gt_db = []
for a in anno:
image_name = a['image']
c = np.array(a['center'], dtype=np.float32)
s = np.array([a['scale'], a['scale']], dtype=np.float32)
# Adjust center/scale slightly to avoid cropping limbs
if c[0] != -1:
c[1] = c[1] + 15 * s[1]
s = s * 1.25
            # MPII annotations use MATLAB's 1-based indexing,
            # so convert to 0-based indices first
c = c - 1
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float32)
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float32)
if self.image_set != 'test':
joints = np.array(a['joints'])
joints[:, 0:2] = joints[:, 0:2] - 1
joints_vis = np.array(a['joints_vis'])
assert len(joints) == self.num_joints, \
'joint num diff: {} vs {}'.format(len(joints),
self.num_joints)
joints_3d[:, 0:2] = joints[:, 0:2]
joints_3d_vis[:, 0] = joints_vis[:]
joints_3d_vis[:, 1] = joints_vis[:]
image_dir = 'images.zip@' if self.data_format == 'zip' else 'images'
gt_db.append(
{
'image': os.path.join(self.root, image_dir, image_name),
'center': c,
'scale': s,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
'filename': '',
'imgnum': 0,
}
)
return gt_db
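# A minimal usage sketch (assumed dataset root, not part of the original code):
#
#     ds = COCODataset(root='data/coco', image_set='train2017', is_train=True,
#                      radius=10, size=256, transparency=0.0)
#     sample = ds[0]
#     sample['edited']                  # target image with colored keypoint circles
#     sample['edit']['c_crossattn']     # e.g. "Circle the left knee of the people with the color red,"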
================================================
FILE: models/InstructDiffusion/dataset/prompt/color_list_train_small.txt
================================================
Red 纯红 #FF0000 255,0,0
Purple 紫色 #800080 128,0,128
Blue 纯蓝 #0000FF 0,0,255
Green 纯绿 #008000 0,128,0
Yellow 纯黄 #FFFF00 255,255,0
White 纯白 #FFFFFF 255,255,255
Black 纯黑 #000000 0,0,0
Gray 灰色 #808080 128,128,128
================================================
FILE: models/InstructDiffusion/dataset/prompt/prompt_deblur.txt
================================================
Sharpen this blurry image
Increase the sharpness of this unclear photo
Correct the lack of focus in this misty picture
Heighten the definition of this smeared image
Clear up this fuzzy picture
Refine this indistinct photograph
Improve the focus of this hazy image
Amend the softness of this out-of-focus photograph
Polish the murkiness of this low-definition photo
Rectify the vagueness of this blurred image
================================================
FILE: models/InstructDiffusion/dataset/prompt/prompt_denoise.txt
================================================
Remove noise from this image
Eliminate the noise in this picture
Purify this photo by removing noise
Clear up the image by filtering out noise
Eradicate the noise from this photograph
Minimize the noise present in this picture
Cancel out the noise within this image
Clean this photo by discarding the noise
Suppress the noise in this visual representation
Rectify the noise interference in this image
================================================
FILE: models/InstructDiffusion/dataset/prompt/prompt_dewatermark.txt
================================================
Remove watermark from this picture
Erase the watermark from this photograph.
Extract the watermark from this image.
Take out the watermark overlay from this photo.
Wipe off the watermark imprint on this image.
Detach the watermark from this visual representation.
Get rid of the watermarking on this picture.
Withdraw the watermark applied to this photograph.
Clean up this image by deleting the watermark.
Unmark this photo by removing the watermark.
================================================
FILE: models/InstructDiffusion/dataset/prompt/prompt_pose.txt
================================================
Circle the {joint} of the people with the color {color},
Use the {color} color to draw circles around the {joint} of the people,
Make {color} circles around the {joint} of the people,
Put {color} circles on the {joint} of the people,
Draw {color} circles over the {joint} of the people,
Surround the {joint} of the people with {color} circles,
Use the color {color} to make circles on the {joint} of the people,
Mark the {joint} of the people with {color} circles,
Create {color} circles around the {joint} of the people,
Use the color {color} to encircle the {joint} of the people,
================================================
FILE: models/InstructDiffusion/dataset/prompt/prompt_seg.txt
================================================
Mark the pixels of {object} in {color} and leave the rest unchanged.
Color the {object}'s pixels in {color}, keeping the remaining pixels unaltered.
Apply {color} to the pixels of {object} while maintaining the current state of other pixels.
Assign {color} to the pixels belonging to {object}, preserving the rest as they are.
For {object}, set its pixels to {color} and let the others remain the same.
Modify the pixels of {object} to {color} without affecting any other pixels.
Set the {object} pixels to {color} and keep the other pixels in their original state.
Update the pixels of {object} to {color}, but leave the other pixels untouched.
Fill in the pixels of {object} with {color}, retaining the existing colors of the remaining pixels.
Change the {object} pixels to {color}, while keeping the other pixels constant.
Paint the pixels of {object} in {color} and maintain the current appearance of the other pixels.
================================================
FILE: models/InstructDiffusion/dataset/seg/coco_stuff.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
# --------------------------------------------------------
from __future__ import annotations
import json
import math
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
import cv2
import os
import random
import copy
from glob import glob
class COCOStuffDataset(Dataset):
def __init__(
self,
path: str,
path_edit: str = "None",
split: str = "train",
splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
crop_res: int = 256,
flip_prob: float = 0.0,
transparency: float = 0,
batch_size: int = 10,
empty_percentage: float = 0,
):
assert split in ("train2017", "val2017")
assert sum(splits) == 1
self.split = split
self.path = path
self.path_edit = path_edit
self.batch_size = batch_size
self.crop_res = crop_res
self.flip_prob = flip_prob
self.empty_percentage = empty_percentage
self.transparency = transparency
if self.split in ["train2017", "val2017"]:
file_list = sorted(glob(os.path.join(self.path, "images", self.split, "*.jpg")))
assert len(file_list) > 0, "{} has no image".format(
os.path.join(self.path, "images", self.split)
)
file_list = [f.split("/")[-1].replace(".jpg", "") for f in file_list]
self.files = file_list
else:
raise ValueError("Invalid split name: {}".format(self.split))
seg_diverse_prompt_path = 'dataset/prompt/prompt_seg.txt'
self.seg_diverse_prompt_list=[]
with open(seg_diverse_prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.seg_diverse_prompt_list.append(line)
line=f.readline()
color_list_file_path='dataset/prompt/color_list_train_small.txt'
self.color_list=[]
with open(color_list_file_path) as f:
line = f.readline()
while line:
line_split = line.strip('\n').split(" ")
if len(line_split)>1:
temp = []
for i in range(4):
temp.append(line_split[i])
self.color_list.append(temp)
line = f.readline()
coco_label_list_path = self.path + '/labels.txt'
self.label_dict={}
with open(coco_label_list_path) as f:
line = f.readline()
while line:
line_split = line.strip('\n').split(": ")
self.label_dict[int(line_split[0])]=line_split[1]
line = f.readline()
def __len__(self) -> int:
length=len(self.files)
return length
def _augmentation_new(self, image, label):
# Cropping
h, w = label.shape
if h > w:
start_h = random.randint(0, h - w)
end_h = start_h + w
image = image[start_h:end_h]
label = label[start_h:end_h]
elif h < w:
start_w = random.randint(0, w - h)
end_w = start_w + h
image = image[:, start_w:end_w]
label = label[:, start_w:end_w]
else:
pass
image = Image.fromarray(image).resize((self.crop_res, self.crop_res), resample=Image.Resampling.LANCZOS)
image = np.asarray(image, dtype=np.uint8)
label = Image.fromarray(label).resize((self.crop_res, self.crop_res), resample=Image.Resampling.NEAREST)
label = np.asarray(label, dtype=np.int64)
return image, label
def __getitem__(self, i):
image_id = self.files[i]
img_path = os.path.join(self.path, "images", self.split, image_id + ".jpg")
mask_path = os.path.join(self.path, "annotations", self.split, image_id + ".png")
label = Image.open(mask_path).convert("L")
image = Image.open(img_path).convert("RGB")
label = np.asarray(label)
image = np.asarray(image)
image, label = self._augmentation_new(image,label)
label_list = np.unique(label)
label_list = list(label_list)
        label_list_rest = [i for i in range(182) if i not in label_list]
if 255 in label_list:
label_list.remove(255)
if len(label_list)!=0:
label_idx = random.choice(label_list)
if random.uniform(0, 1) < self.empty_percentage:
label_idx = random.choice(label_list_rest)
class_name = self.label_dict[label_idx+1]
prompt = random.choice(self.seg_diverse_prompt_list)
color = random.choice(self.color_list)
color_name = color[0]
prompt = prompt.format(color=color_name.lower(), object=class_name.lower())
R, G, B = color[3].split(",")
R = int(R)
G = int(G)
B = int(B)
else:
label_idx = 200
prompt = "leave the picture as it is."
mask = (label==label_idx)
image_0 = Image.fromarray(image)
image_1 = copy.deepcopy(image)
if len(label_list)!=0:
image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R
image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G
image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B
image_1 = Image.fromarray(image_1)
# return image_0, image_1, prompt
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
mask = torch.tensor(mask).float()
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
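# A minimal usage sketch (assumed paths, not part of the original code): the dataset
# expects <path>/images/<split>/*.jpg, matching PNG masks in <path>/annotations/<split>/
# and a <path>/labels.txt file, and every sample is an InstructPix2Pix-style dict:
#
#     ds = COCOStuffDataset(path="data/coco_stuff", split="train2017")
#     sample = ds[0]
#     sample["edited"].shape            # torch.Size([3, crop_res, crop_res]), values in [-1, 1]
#     sample["edit"]["c_crossattn"]     # the generated segmentation prompt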
================================================
FILE: models/InstructDiffusion/dataset/seg/grefcoco.py
================================================
"""
grefer v0.1
This interface provides access to gRefCOCO.
The following API functions are defined:
G_REFER - REFER api class
getRefIds - get ref ids that satisfy given filter conditions.
getAnnIds - get ann ids that satisfy given filter conditions.
getImgIds - get image ids that satisfy given filter conditions.
getCatIds - get category ids that satisfy given filter conditions.
loadRefs - load refs with the specified ref ids.
loadAnns - load anns with the specified ann ids.
loadImgs - load images with the specified image ids.
loadCats - load category names with the specified category ids.
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
showRef - show image, segmentation or box of the referred object with the ref
getMaskByRef - get mask and area of the referred object given ref or ref ids
getMask - get mask and area of the referred object given ref
showMask - show mask of the referred object given ref
"""
import os.path as osp
import json
import pickle
import time
import itertools
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
import numpy as np
from pycocotools import mask
class G_REFER:
def __init__(self, data_root, dataset='grefcoco', splitBy='unc'):
# provide data_root folder which contains grefcoco
print('loading dataset %s into memory...' % dataset)
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
self.DATA_DIR = osp.join(data_root, dataset)
if dataset in ['grefcoco']:
self.IMAGE_DIR = osp.join(data_root, 'images/train2014')
else:
raise KeyError('No refer dataset is called [%s]' % dataset)
tic = time.time()
# load refs from data/dataset/refs(dataset).json
self.data = {}
self.data['dataset'] = dataset
ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).p')
if osp.exists(ref_file):
self.data['refs'] = pickle.load(open(ref_file, 'rb'),fix_imports=True)
else:
ref_file = osp.join(self.DATA_DIR, f'grefs({splitBy}).json')
if osp.exists(ref_file):
self.data['refs'] = json.load(open(ref_file, 'rb'))
else:
raise FileNotFoundError('JSON file not found')
# load annotations from data/dataset/instances.json
instances_file = osp.join(self.DATA_DIR, 'instances.json')
instances = json.load(open(instances_file, 'r'))
self.data['images'] = instances['images']
self.data['annotations'] = instances['annotations']
self.data['categories'] = instances['categories']
# create index
self.createIndex()
print('DONE (t=%.2fs)' % (time.time()-tic))
@staticmethod
def _toList(x):
return x if isinstance(x, list) else [x]
@staticmethod
def match_any(a, b):
a = a if isinstance(a, list) else [a]
b = b if isinstance(b, list) else [b]
return set(a) & set(b)
def createIndex(self):
# create sets of mapping
# 1) Refs: {ref_id: ref}
# 2) Anns: {ann_id: ann}
# 3) Imgs: {image_id: image}
# 4) Cats: {category_id: category_name}
# 5) Sents: {sent_id: sent}
# 6) imgToRefs: {image_id: refs}
# 7) imgToAnns: {image_id: anns}
# 8) refToAnn: {ref_id: ann}
# 9) annToRef: {ann_id: ref}
# 10) catToRefs: {category_id: refs}
# 11) sentToRef: {sent_id: ref}
# 12) sentToTokens: {sent_id: tokens}
print('creating index...')
# fetch info from instances
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
Anns[-1] = None
for ann in self.data['annotations']:
Anns[ann['id']] = ann
imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
for img in self.data['images']:
Imgs[img['id']] = img
for cat in self.data['categories']:
Cats[cat['id']] = cat['name']
# fetch info from refs
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
Sents, sentToRef, sentToTokens = {}, {}, {}
availableSplits = []
for ref in self.data['refs']:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
category_id = ref['category_id']
image_id = ref['image_id']
if ref['split'] not in availableSplits:
availableSplits.append(ref['split'])
# add mapping related to ref
if ref_id in Refs:
print('Duplicate ref id')
Refs[ref_id] = ref
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
category_id = self._toList(category_id)
added_cats = []
for cat in category_id:
if cat not in added_cats:
added_cats.append(cat)
catToRefs[cat] = catToRefs.get(cat, []) + [ref]
ann_id = self._toList(ann_id)
refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
for ann_id_n in ann_id:
annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
# add mapping of sent
for sent in ref['sentences']:
Sents[sent['sent_id']] = sent
sentToRef[sent['sent_id']] = ref
sentToTokens[sent['sent_id']] = sent['tokens']
# create class members
self.Refs = Refs
self.Anns = Anns
self.Imgs = Imgs
self.Cats = Cats
self.Sents = Sents
self.imgToRefs = imgToRefs
self.imgToAnns = imgToAnns
self.refToAnn = refToAnn
self.annToRef = annToRef
self.catToRefs = catToRefs
self.sentToRef = sentToRef
self.sentToTokens = sentToTokens
self.availableSplits = availableSplits
print('index created.')
def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
image_ids = self._toList(image_ids)
cat_ids = self._toList(cat_ids)
split = self._toList(split)
for s in split:
if s not in self.availableSplits:
raise ValueError(f'Invalid split name: {s}')
refs = self.data['refs']
if len(image_ids) > 0:
lists = [self.imgToRefs[image_id] for image_id in image_ids]
refs = list(itertools.chain.from_iterable(lists))
if len(cat_ids) > 0:
refs = [ref for ref in refs if self.match_any(ref['category_id'], cat_ids)]
if len(split) > 0:
refs = [ref for ref in refs if ref['split'] in split]
ref_ids = [ref['ref_id'] for ref in refs]
return ref_ids
def getAnnIds(self, image_ids=[], ref_ids=[]):
image_ids = self._toList(image_ids)
ref_ids = self._toList(ref_ids)
if any([len(image_ids), len(ref_ids)]):
if len(image_ids) > 0:
lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns]
anns = list(itertools.chain.from_iterable(lists))
else:
anns = self.data['annotations']
ann_ids = [ann['id'] for ann in anns]
if len(ref_ids) > 0:
lists = [self.Refs[ref_id]['ann_id'] for ref_id in ref_ids]
anns_by_ref_id = list(itertools.chain.from_iterable(lists))
ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
else:
ann_ids = [ann['id'] for ann in self.data['annotations']]
return ann_ids
def getImgIds(self, ref_ids=[]):
ref_ids = self._toList(ref_ids)
if len(ref_ids) > 0:
image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
else:
image_ids = self.Imgs.keys()
return image_ids
def getCatIds(self):
return self.Cats.keys()
def loadRefs(self, ref_ids=[]):
return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
def loadAnns(self, ann_ids=[]):
if isinstance(ann_ids, str):
ann_ids = int(ann_ids)
return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
def loadImgs(self, image_ids=[]):
return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
def loadCats(self, cat_ids=[]):
return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
def getRefBox(self, ref_id):
anns = self.refToAnn[ref_id]
return [ann['bbox'] for ann in anns] # [x, y, w, h]
def showRef(self, ref, seg_box='seg'):
ax = plt.gca()
# show image
image = self.Imgs[ref['image_id']]
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
ax.imshow(I)
# show refer expression
for sid, sent in enumerate(ref['sentences']):
print('%s. %s' % (sid+1, sent['sent']))
# show segmentations
if seg_box == 'seg':
ann_id = ref['ann_id']
ann = self.Anns[ann_id]
polygons = []
color = []
c = 'none'
if type(ann['segmentation'][0]) == list:
# polygon used for refcoco*
for seg in ann['segmentation']:
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
polygons.append(Polygon(poly, True, alpha=0.4))
color.append(c)
p = PatchCollection(polygons, facecolors=color, edgecolors=(1,1,0,0), linewidths=3, alpha=1)
ax.add_collection(p) # thick yellow polygon
p = PatchCollection(polygons, facecolors=color, edgecolors=(1,0,0,0), linewidths=1, alpha=1)
ax.add_collection(p) # thin red polygon
else:
# mask used for refclef
rle = ann['segmentation']
m = mask.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
color_mask = np.array([2.0,166.0,101.0])/255
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
# show bounding-box
elif seg_box == 'box':
ann_id = ref['ann_id']
ann = self.Anns[ann_id]
bbox = self.getRefBox(ref['ref_id'])
box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
ax.add_patch(box_plot)
def getMask(self, ann):
if not ann:
return None
if ann['iscrowd']:
raise ValueError('Crowd object')
image = self.Imgs[ann['image_id']]
if type(ann['segmentation'][0]) == list: # polygon
rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
else:
rle = ann['segmentation']
m = mask.decode(rle)
        m = np.sum(m, axis=2)  # sometimes there are multiple binary maps (corresponding to multiple segs)
m = m.astype(np.uint8) # convert to np.uint8
# compute area
area = sum(mask.area(rle)) # should be close to ann['area']
return {'mask': m, 'area': area}
def getMaskByRef(self, ref=None, ref_id=None, merge=False):
if not ref and not ref_id:
raise ValueError
if ref:
ann_ids = ref['ann_id']
ref_id = ref['ref_id']
else:
ann_ids = self.getAnnIds(ref_ids=ref_id)
if ann_ids == [-1]:
img = self.Imgs[self.Refs[ref_id]['image_id']]
return {
'mask': np.zeros([img['height'], img['width']], dtype=np.uint8),
'empty': True
}
anns = self.loadAnns(ann_ids)
mask_list = [self.getMask(ann) for ann in anns if not ann['iscrowd']]
if merge:
merged_masks = sum([mask['mask'] for mask in mask_list])
merged_masks[np.where(merged_masks>1)] = 1
return {
'mask': merged_masks,
'empty': False
}
else:
return mask_list
def showMask(self, ref):
M = self.getMask(ref)
msk = M['mask']
ax = plt.gca()
ax.imshow(msk)
================================================
FILE: models/InstructDiffusion/dataset/seg/grefcoco_segmentation.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
# --------------------------------------------------------
from __future__ import annotations
import os
import random
import copy
import json
import math
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from dataset.seg.grefcoco import G_REFER
class GrefCOCODataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
transparency: float = 0.0,
test: bool = False,
):
assert split in ("train", "val", "test")
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.G_ref_dataset=G_REFER(data_root=path)
self.IMAGE_DIR = os.path.join(path, 'images/train2014')
self.list_ref=self.G_ref_dataset.getRefIds(split=split)
self.transparency = transparency
self.test = test
seg_diverse_prompt_path = 'dataset/prompt/prompt_seg.txt'
self.seg_diverse_prompt_list=[]
with open(seg_diverse_prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.seg_diverse_prompt_list.append(line)
line=f.readline()
color_list_file_path='dataset/prompt/color_list_train_small.txt'
self.color_list=[]
with open(color_list_file_path) as f:
line = f.readline()
while line:
line_split = line.strip('\n').split(" ")
if len(line_split)>1:
temp = []
for i in range(4):
temp.append(line_split[i])
self.color_list.append(temp)
line = f.readline()
def __len__(self) -> int:
return len(self.list_ref)
def _augmentation_new(self, image, label):
# Cropping
h, w = label.shape
if h > w:
start_h = random.randint(0, h - w)
end_h = start_h + w
image = image[start_h:end_h]
label = label[start_h:end_h]
elif h < w:
start_w = random.randint(0, w - h)
end_w = start_w + h
image = image[:, start_w:end_w]
label = label[:, start_w:end_w]
else:
pass
image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS)
image = np.asarray(image, dtype=np.uint8)
label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST)
label = np.asarray(label, dtype=np.int64)
return image, label
def __getitem__(self, i: int) -> dict[str, Any]:
ref_ids = self.list_ref[i]
ref = self.G_ref_dataset.loadRefs(ref_ids)[0]
sentences = random.choice(ref['sentences'])['sent']
prompt = random.choice(self.seg_diverse_prompt_list)
color = random.choice(self.color_list)
color_name = color[0]
prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
R, G, B = color[3].split(",")
R = int(R)
G = int(G)
B = int(B)
image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name']
image_path = os.path.join(self.IMAGE_DIR,image_name)
mask = self.G_ref_dataset.getMaskByRef(ref=ref,merge=True)['mask']
image = Image.open(image_path).convert("RGB")
image = np.asarray(image)
image, mask = self._augmentation_new(image,mask)
mask = (mask == 1)
image_0 = Image.fromarray(image)
image_1 = copy.deepcopy(image)
image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R
image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G
image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B
image_1 = Image.fromarray(image_1)
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
mask = torch.tensor(mask).float()
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/seg/refcoco.py
================================================
__author__ = 'licheng'
"""
This interface provides access to four datasets:
1) refclef
2) refcoco
3) refcoco+
4) refcocog
split by unc and google
The following API functions are defined:
REFER - REFER api class
getRefIds - get ref ids that satisfy given filter conditions.
getAnnIds - get ann ids that satisfy given filter conditions.
getImgIds - get image ids that satisfy given filter conditions.
getCatIds - get category ids that satisfy given filter conditions.
loadRefs - load refs with the specified ref ids.
loadAnns - load anns with the specified ann ids.
loadImgs - load images with the specified image ids.
loadCats - load category names with the specified category ids.
getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
showRef - show image, segmentation or box of the referred object with the ref
getMask - get mask and area of the referred object given ref
showMask - show mask of the referred object given ref
"""
import sys
sys.path.append("./dataset")
import os.path as osp
import json
import pickle
import time
import itertools
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon, Rectangle
from pprint import pprint
import numpy as np
from pycocotools import mask
# import cv2
# from skimage.measure import label, regionprops
class REFER:
def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
# provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
# also provide dataset name and splitBy information
# e.g., dataset = 'refcoco', splitBy = 'unc'
print('loading dataset %s into memory...' % dataset)
self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
self.DATA_DIR = osp.join(data_root, dataset)
if dataset in ['refcoco', 'refcoco+', 'refcocog']:
self.IMAGE_DIR = osp.join(data_root, 'images/mscoco/images/train2014')
elif dataset == 'refclef':
self.IMAGE_DIR = osp.join(data_root, 'images/saiapr_tc-12')
else:
print('No refer dataset is called [%s]' % dataset)
sys.exit()
# load refs from data/dataset/refs(dataset).json
tic = time.time()
ref_file = osp.join(self.DATA_DIR, 'refs('+splitBy+').p')
self.data = {}
self.data['dataset'] = dataset
self.data['refs'] = pickle.load(open(ref_file, 'rb'),fix_imports=True)
# load annotations from data/dataset/instances.json
instances_file = osp.join(self.DATA_DIR, 'instances.json')
instances = json.load(open(instances_file, 'r'))
self.data['images'] = instances['images']
self.data['annotations'] = instances['annotations']
self.data['categories'] = instances['categories']
# create index
self.createIndex()
print('DONE (t=%.2fs)' % (time.time()-tic))
def createIndex(self):
# create sets of mapping
# 1) Refs: {ref_id: ref}
# 2) Anns: {ann_id: ann}
# 3) Imgs: {image_id: image}
# 4) Cats: {category_id: category_name}
# 5) Sents: {sent_id: sent}
# 6) imgToRefs: {image_id: refs}
# 7) imgToAnns: {image_id: anns}
# 8) refToAnn: {ref_id: ann}
# 9) annToRef: {ann_id: ref}
# 10) catToRefs: {category_id: refs}
# 11) sentToRef: {sent_id: ref}
# 12) sentToTokens: {sent_id: tokens}
print('creating index...')
# fetch info from instances
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
for ann in self.data['annotations']:
Anns[ann['id']] = ann
imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'], []) + [ann]
for img in self.data['images']:
Imgs[img['id']] = img
for cat in self.data['categories']:
Cats[cat['id']] = cat['name']
# fetch info from refs
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
Sents, sentToRef, sentToTokens = {}, {}, {}
for ref in self.data['refs']:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
category_id = ref['category_id']
image_id = ref['image_id']
# add mapping related to ref
Refs[ref_id] = ref
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
refToAnn[ref_id] = Anns[ann_id]
annToRef[ann_id] = ref
# add mapping of sent
for sent in ref['sentences']:
Sents[sent['sent_id']] = sent
sentToRef[sent['sent_id']] = ref
sentToTokens[sent['sent_id']] = sent['tokens']
# create class members
self.Refs = Refs
self.Anns = Anns
self.Imgs = Imgs
self.Cats = Cats
self.Sents = Sents
self.imgToRefs = imgToRefs
self.imgToAnns = imgToAnns
self.refToAnn = refToAnn
self.annToRef = annToRef
self.catToRefs = catToRefs
self.sentToRef = sentToRef
self.sentToTokens = sentToTokens
print('index created.')
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
image_ids = image_ids if type(image_ids) == list else [image_ids]
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
if len(image_ids)==len(cat_ids)==len(ref_ids)==len(split)==0:
refs = self.data['refs']
else:
if not len(image_ids) == 0:
                lists = [self.imgToRefs[image_id] for image_id in image_ids]
                refs = list(itertools.chain.from_iterable(lists))
else:
refs = self.data['refs']
if not len(cat_ids) == 0:
refs = [ref for ref in refs if ref['category_id'] in cat_ids]
if not len(ref_ids) == 0:
refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
if not len(split) == 0:
if split in ['testA', 'testB', 'testC']:
refs = [ref for ref in refs if split[-1] in ref['split']] # we also consider testAB, testBC, ...
elif split in ['testAB', 'testBC', 'testAC']:
refs = [ref for ref in refs if ref['split'] == split] # rarely used I guess...
elif split == 'test':
refs = [ref for ref in refs if 'test' in ref['split']]
elif split == 'train' or split == 'val':
refs = [ref for ref in refs if ref['split'] == split]
else:
print('No such split [%s]' % split)
sys.exit()
ref_ids = [ref['ref_id'] for ref in refs]
return ref_ids
def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
image_ids = image_ids if type(image_ids) == list else [image_ids]
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
ann_ids = [ann['id'] for ann in self.data['annotations']]
else:
if not len(image_ids) == 0:
lists = [self.imgToAnns[image_id] for image_id in image_ids if image_id in self.imgToAnns] # list of [anns]
anns = list(itertools.chain.from_iterable(lists))
else:
anns = self.data['annotations']
if not len(cat_ids) == 0:
anns = [ann for ann in anns if ann['category_id'] in cat_ids]
ann_ids = [ann['id'] for ann in anns]
if not len(ref_ids) == 0:
                ann_ids = list(set(ann_ids).intersection(set([self.Refs[ref_id]['ann_id'] for ref_id in ref_ids])))
return ann_ids
def getImgIds(self, ref_ids=[]):
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
if not len(ref_ids) == 0:
image_ids = list(set([self.Refs[ref_id]['image_id'] for ref_id in ref_ids]))
else:
image_ids = self.Imgs.keys()
return image_ids
def getCatIds(self):
return self.Cats.keys()
def loadRefs(self, ref_ids=[]):
if type(ref_ids) == list:
return [self.Refs[ref_id] for ref_id in ref_ids]
elif type(ref_ids) == int:
return [self.Refs[ref_ids]]
def loadAnns(self, ann_ids=[]):
if type(ann_ids) == list:
return [self.Anns[ann_id] for ann_id in ann_ids]
        elif type(ann_ids) == int or type(ann_ids) == str:
return [self.Anns[ann_ids]]
def loadImgs(self, image_ids=[]):
if type(image_ids) == list:
return [self.Imgs[image_id] for image_id in image_ids]
elif type(image_ids) == int:
return [self.Imgs[image_ids]]
def loadCats(self, cat_ids=[]):
if type(cat_ids) == list:
return [self.Cats[cat_id] for cat_id in cat_ids]
elif type(cat_ids) == int:
return [self.Cats[cat_ids]]
def getRefBox(self, ref_id):
ref = self.Refs[ref_id]
ann = self.refToAnn[ref_id]
return ann['bbox'] # [x, y, w, h]
def showRef(self, ref, seg_box='seg'):
ax = plt.gca()
# show image
image = self.Imgs[ref['image_id']]
I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
ax.imshow(I)
# show refer expression
for sid, sent in enumerate(ref['sentences']):
print('%s. %s' % (sid+1, sent['sent']))
# show segmentations
if seg_box == 'seg':
ann_id = ref['ann_id']
ann = self.Anns[ann_id]
polygons = []
color = []
c = 'none'
if type(ann['segmentation'][0]) == list:
# polygon used for refcoco*
for seg in ann['segmentation']:
                    poly = np.array(seg).reshape((len(seg) // 2, 2))
polygons.append(Polygon(poly, True, alpha=0.4))
color.append(c)
p = PatchCollection(polygons, facecolors=color, edgecolors=(1,1,0,0), linewidths=3, alpha=1)
ax.add_collection(p) # thick yellow polygon
p = PatchCollection(polygons, facecolors=color, edgecolors=(1,0,0,0), linewidths=1, alpha=1)
ax.add_collection(p) # thin red polygon
else:
# mask used for refclef
rle = ann['segmentation']
m = mask.decode(rle)
img = np.ones( (m.shape[0], m.shape[1], 3) )
color_mask = np.array([2.0,166.0,101.0])/255
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack( (img, m*0.5) ))
# show bounding-box
elif seg_box == 'box':
ann_id = ref['ann_id']
ann = self.Anns[ann_id]
bbox = self.getRefBox(ref['ref_id'])
box_plot = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, edgecolor='green', linewidth=3)
ax.add_patch(box_plot)
def getMask(self, ref):
# return mask, area and mask-center
ann = self.refToAnn[ref['ref_id']]
image = self.Imgs[ref['image_id']]
if type(ann['segmentation'][0]) == list: # polygon
rle = mask.frPyObjects(ann['segmentation'], image['height'], image['width'])
else:
rle = ann['segmentation']
m = mask.decode(rle)
        m = np.sum(m, axis=2)  # sometimes there are multiple binary maps (corresponding to multiple segs)
m = m.astype(np.uint8) # convert to np.uint8
# compute area
area = sum(mask.area(rle)) # should be close to ann['area']
return {'mask': m, 'area': area}
# # position
# position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
# position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
# # mass position (if there were multiple regions, we use the largest one.)
# label_m = label(m, connectivity=m.ndim)
# regions = regionprops(label_m)
# if len(regions) > 0:
# largest_id = np.argmax(np.array([props.filled_area for props in regions]))
# largest_props = regions[largest_id]
# mass_y, mass_x = largest_props.centroid
# else:
# mass_x, mass_y = position_x, position_y
# # if centroid is not in mask, we find the closest point to it from mask
# if m[mass_y, mass_x] != 1:
# print 'Finding closes mask point ...'
# kernel = np.ones((10, 10),np.uint8)
# me = cv2.erode(m, kernel, iterations = 1)
# points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
# points = np.array(points)
# dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
# id = np.argsort(dist)[0]
# mass_y, mass_x = points[id]
# # return
# return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
# # show image and mask
# I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
# plt.figure()
# plt.imshow(I)
# ax = plt.gca()
# img = np.ones( (m.shape[0], m.shape[1], 3) )
# color_mask = np.array([2.0,166.0,101.0])/255
# for i in range(3):
# img[:,:,i] = color_mask[i]
# ax.imshow(np.dstack( (img, m*0.5) ))
# plt.show()
def showMask(self, ref):
M = self.getMask(ref)
msk = M['mask']
ax = plt.gca()
ax.imshow(msk)
if __name__ == '__main__':
    refer = REFER(data_root='path/to/refer/data', dataset='refcocog', splitBy='google')  # placeholder data_root; point it at the REFER data folder
ref_ids = refer.getRefIds()
print(len(ref_ids))
print(len(refer.Imgs))
print(len(refer.imgToRefs))
ref_ids = refer.getRefIds(split='train')
print('There are %s training referred objects.' % len(ref_ids))
for ref_id in ref_ids:
ref = refer.loadRefs(ref_id)[0]
if len(ref['sentences']) < 2:
continue
pprint(ref)
print('The label is %s.' % refer.Cats[ref['category_id']])
plt.figure()
refer.showRef(ref, seg_box='box')
plt.show()
================================================
FILE: models/InstructDiffusion/dataset/seg/refcoco_segmentation.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
# --------------------------------------------------------
from __future__ import annotations
import os
import random
import copy
import json
import math
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from dataset.seg.refcoco import REFER
class RefCOCODataset(Dataset):
def __init__(
self,
path: str,
split: str = "train",
min_resize_res: int = 256,
max_resize_res: int = 256,
crop_res: int = 256,
flip_prob: float = 0.0,
transparency: float = 0.0,
test: bool = False,
):
assert split in ("train", "val", "test")
self.path = path
self.min_resize_res = min_resize_res
self.max_resize_res = max_resize_res
self.crop_res = crop_res
self.flip_prob = flip_prob
self.G_ref_dataset=REFER(data_root=path)
self.IMAGE_DIR = os.path.join(path, 'images/train2014')
self.list_ref=self.G_ref_dataset.getRefIds(split=split)
self.transparency = transparency
self.test = test
seg_diverse_prompt_path = 'dataset/prompt/prompt_seg.txt'
self.seg_diverse_prompt_list=[]
with open(seg_diverse_prompt_path) as f:
line=f.readline()
while line:
line=line.strip('\n')
self.seg_diverse_prompt_list.append(line)
line=f.readline()
color_list_file_path='dataset/prompt/color_list_train_small.txt'
self.color_list=[]
with open(color_list_file_path) as f:
line = f.readline()
while line:
line_split = line.strip('\n').split(" ")
if len(line_split)>1:
temp = []
for i in range(4):
temp.append(line_split[i])
self.color_list.append(temp)
line = f.readline()
def __len__(self) -> int:
return len(self.list_ref)
def _augmentation_new(self, image, label):
# Cropping
h, w = label.shape
if h > w:
start_h = random.randint(0, h - w)
end_h = start_h + w
image = image[start_h:end_h]
label = label[start_h:end_h]
elif h < w:
start_w = random.randint(0, w - h)
end_w = start_w + h
image = image[:, start_w:end_w]
label = label[:, start_w:end_w]
else:
pass
image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS)
image = np.asarray(image, dtype=np.uint8)
label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST)
label = np.asarray(label, dtype=np.int64)
return image, label
def __getitem__(self, i: int) -> dict[str, Any]:
ref_ids = self.list_ref[i]
ref = self.G_ref_dataset.loadRefs(ref_ids)[0]
sentences = random.choice(ref['sentences'])['sent']
prompt = random.choice(self.seg_diverse_prompt_list)
color = random.choice(self.color_list)
color_name = color[0]
prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
R, G, B = color[3].split(",")
R = int(R)
G = int(G)
B = int(B)
image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name']
image_path = os.path.join(self.IMAGE_DIR,image_name)
mask = self.G_ref_dataset.getMask(ref=ref)['mask']
image = Image.open(image_path).convert("RGB")
image = np.asarray(image)
image, mask = self._augmentation_new(image,mask)
mask = (mask == 1)
image_0 = Image.fromarray(image)
image_1 = copy.deepcopy(image)
image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R
image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G
image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B
image_1 = Image.fromarray(image_1)
        resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
        image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
        image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
crop = torchvision.transforms.RandomCrop(self.crop_res)
flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
mask = torch.tensor(mask).float()
return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
================================================
FILE: models/InstructDiffusion/dataset/utils/zip_manager.py
================================================
import zipfile
import os.path as osp
# import lmdb
import logging
from PIL import Image
import pickle
import io
import glob
import os
from pathlib import Path
import time
from threading import Thread
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
home = str(Path.home())
abs_blob_path=os.path.realpath("/mnt/blob/")
CACHE_FOLDER=os.path.join(home,"caching")
USE_CACHE=True
def norm(path):
assert "*" not in path
return os.path.realpath(os.path.abspath(path))
def in_blob(file):
if abs_blob_path in file:
return True
else:
return False
def map_name(file):
path=norm(file)
    if path.startswith(abs_blob_path):
        path = path[len(abs_blob_path):]
    path = path.lstrip("/")
path=path.replace("/","_")
assert len(path)<250
return path
def preload(db,sync=False):
if sync:
db.initialize()
else:
p = Thread(target=db.initialize)
p.start()
def get_keys_from_lmdb(db):
with db.begin(write=False) as txn:
return list(txn.cursor().iternext(values=False))
def decode_img(byteflow):
try:
img=Image.open(io.BytesIO(byteflow)).convert("RGB")
img.load()
except:
img = Image.open("white.jpeg").convert("RGB")
img.load()
return img
def decode_text(byteflow):
return pickle.loads(byteflow)
decode_funcs={
"image": decode_img,
"text": decode_text
}
class ZipManager:
    def __init__(self, zip_path, data_type, prefix=None) -> None:
        self.decode_func = decode_funcs[data_type]
        self.zip_path = zip_path
        self._init = False
        preload(self)

    def deinitialize(self):
        self.zip_fd.close()
        del self.zip_fd
        self._init = False

    def initialize(self, close=True):
        self.zip_fd = zipfile.ZipFile(self.zip_path, mode="r")
        if not hasattr(self, "_keys"):
            self._keys = self.zip_fd.namelist()
        self._init = True
        if close:
            self.deinitialize()

    @property
    def keys(self):
        # Block until the background preload has populated the name list.
        while not hasattr(self, "_keys"):
            time.sleep(0.1)
        return self._keys

    def get(self, name):
        if not self._init:
            self.initialize(close=False)
        byteflow = self.zip_fd.read(name)
        return self.decode_func(byteflow)
class MultipleZipManager:
    def __init__(self, files: list, data_type, sync=True):
        self.files = files
        self._is_init = False
        self.data_type = data_type
        if sync:
            print("sync", files)
            self.initialize()
        else:
            print("async", files)
            preload(self)
        print("initialize over")

    def initialize(self):
        self.mapping = {}
        self.managers = {}
        for file in self.files:
            manager = ZipManager(file, self.data_type)
            self.managers[file] = manager

        for file, manager in self.managers.items():
            print(file)
            # print("loading")
            logging.info(f"{file} loading")
            keys = manager.keys
            for key in keys:
                self.mapping[key] = file
            logging.info(f"{file} loaded, size = {len(keys)}")
            print("loaded")

        self._keys = list(self.mapping.keys())
        self._is_init = True

    @property
    def keys(self):
        while not self._is_init:
            time.sleep(0.1)
        return self._keys

    def get(self, name):
        data = self.managers[self.mapping[name]].get(name)
        return data
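

# Minimal usage sketch of the two managers above (illustrative addition; the
# archive names are placeholders, not files shipped with the repository).
if __name__ == "__main__":
    single = ZipManager("part0.zip", data_type="image")
    name = single.keys[0]                      # blocks until the background preload finishes
    print(name, single.get(name).size)         # decoded PIL.Image from the zip entry

    multi = MultipleZipManager(["part0.zip", "part1.zip"], data_type="image", sync=True)
    print(len(multi.keys), "entries routed across both archives")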
================================================
FILE: models/InstructDiffusion/edit_app.py
================================================
# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Tiankai Hang (tkhang@seu.edu.cn)
# --------------------------------------------------------
import os
import sys
import re
import math
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch import autocast
import einops
from einops import rearrange
import gradio as gr
import k_diffusion as K
import requests
from functools import partial
from copy import deepcopy
from PIL import Image, ImageOps
import click
sys.path.append("./stable_diffusion")
from stable_diffusion.ldm.util import instantiate_from_config
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    model = instantiate_from_config(config.model)
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if 'state_dict' in pl_sd:
        pl_sd = pl_sd['state_dict']
    m, u = model.load_state_dict(pl_sd, strict=False)
    print(m, u)
    return model
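

# Minimal usage sketch (illustrative only; the checkpoint path is a placeholder,
# the config path is the one shipped in this repository):
#
#   config = OmegaConf.load("configs/instruct_diffusion.yaml")
#   model = load_model_from_config(config, "checkpoints/instruct_diffusion.ckpt")
#   model.eval().cuda()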
def read_content(file_path: str) -> str:
    """Read the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
def get_header():
    content = """
    <div style="text-align: center; max-width: 650px; margin: 0 auto;">
        <div style="
            display: inline-flex;
            gap: 0.8rem;
            font-size: 1.75rem;
            justify-content: center;
            margin-bottom: 10px;
        ">
            <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
                InstructDiffusion 🎨
            </h1>
        </div>
        <div>
            <p style="align-items: center; margin-bottom: 7px;">
                InstructDiffusion: upload a source image and write an instruction to perform keypoint detection, referring segmentation, or image editing.
            </p>
            <p style="align-items: center; margin-bottom: 7px;">
                The paper is available on <a style="text-decoration: underline;" href="https://gengzigang.github.io/instructdiffusion.github.io/">arXiv</a>. If you like this demo, please help to ⭐ the <a style="text-decoration: underline;" href="https://github.com/cientgu/InstructDiffusion">GitHub repo</a> 😊.
            </p>
        </div>
    </div>
    """
    return content
class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        # Run three conditionings in one batch: (text + image), (no text + image),
        # and (text + no image), then combine them with separate guidance scales.
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], cond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_txt_cond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        return 0.5 * (out_img_cond + out_txt_cond) + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_cond - out_txt_cond)
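

# The return above is a two-scale classifier-free guidance rule: starting from the
# average of the image-only and text-only predictions, it moves toward the fully
# conditioned prediction by text_cfg_scale along the text direction and by
# image_cfg_scale along the image direction. A quick check (illustrative only, not
# part of the app): with both scales set to 0.5 the combination collapses to the
# fully conditioned prediction.
#
#   full, img_only, txt_only = torch.randn(3, 4)
#   combined = 0.5 * (img_only + txt_only) + 0.5 * (full - img_only) + 0.5 * (full - txt_only)
#   assert torch.allclose(combined, full)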
def predict(
        model, model_wrap,
        model_wrap_cfg,
        null_token, resolution,
        input_img, edit, seed, steps, cfg_text, cfg_image,
        stochastic_steps=0, sampler="euler", additional={}):
    # set seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.empty_cache()

    if isinstance(input_img, str):
        if input_img.startswith("http"):
            input_image = Image.open(requests.get(input_img, stream=True).raw).convert("RGB")
        else:
            input_image = Image.open(input_img).convert("RGB")
        width, height = input_image.size
        factor = resolution / max(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        if hasattr(Image, "Resampling"):
            input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
        else:
            input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
        input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
    # if PIL Image
    elif isinstance(input_img, Image.Image):
        input_image = input_img
        width, height = input_image.size
        factor = resolution / max(width, height)
        # factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        if hasattr(Image, "Resampling"):
            input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
        else:
            input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
        input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
    elif isinstance(input_img, dict):
        input_image = input_img["image"].convert("RGB")
        width, height = input_image.size
        factor = resolution / max(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        if hasattr(Image, "Resampling"):
            input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
        else:
            input_image = ImageOps.fit(input_image, (width, height), method=Image.LANCZOS)
        input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
        input_image = rearrange(input_image, "h w c -> 1 c h w").cuda()
    assert input_image is not None

    # print input image size
    print(input_image.shape, factor, width, height)

    with torch.no_grad(), autocast("cuda"):
        cond = {}
        cond["c_crossattn"] = [model.get_learned_conditioning([edit])]
        cond["c_concat"] = [model.encode_first_stage(input_image).mode()]

        uncond = {}
        if "txt_embed" in additional:
            uncond["c_crossattn"] = [additional["txt_embed"].cuda().unsqueeze(0)]
        else:
            uncond["c_crossattn"] = [null_token]
        if "img_embed" in additional:
            # uncond["c_concat"] = [additional["img_embed"].cuda()]
            # resize to cond["c_concat"][0]
            uncond["c_concat"] = [additional["img_embed"].cuda()]
            uncond["c_concat"][0] = F.interpolate(uncond["c_concat"][0], size=cond["c_concat"][0].shape[-2:], mode="bilinear", align_corners=False)
        else:
            uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(steps)
        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": cfg_text,
            "image_cfg_scale": cfg_image,
        }

        if stochastic_steps <= 0:
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            if sampler == "euler":
                z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
            elif sampler == "heun":
                z = K.sampling.sample_heun(model_wrap_cfg, z, sigmas, extra_args=extra_args)
        else:
            z = torch.randn_like(cond["c_concat"][0]) * sigmas[stochastic_steps] + cond["c_concat"][0]
            z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas[stochastic_steps:], extra_args=extra_args)

        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = 255.0 * rearrange(x, "1 c h w -> h w c")
        edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())

    # input_image to PIL
    input_image = torch.clamp((input_image + 1.0) / 2.0, min=0.0, max=1.0)
    input_image = 255.0 * rearrange(input_image, "1 c h w -> h w c")
    input_image = Image.fromarray(input_image.type(torch.uint8).cpu().numpy())

    return edited_image  # , gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
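

# Minimal call sketch for predict() (illustrative; "example.jpg" is a placeholder,
# and the wrapper objects are wired the usual k-diffusion way rather than copied
# from the app below):
#
#   config = OmegaConf.load("configs/instruct_diffusion.yaml")
#   model = load_model_from_config(config, "checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt").eval().cuda()
#   model_wrap = K.external.CompVisDenoiser(model)
#   model_wrap_cfg = CFGDenoiser(model_wrap)
#   null_token = model.get_learned_conditioning([""])
#   edited = predict(model, model_wrap, model_wrap_cfg, null_token, resolution=512,
#                    input_img="example.jpg", edit="make the sky pink",
#                    seed=0, steps=50, cfg_text=7.5, cfg_image=1.25)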
@click.command()
@click.option("--ckpt", type=str, default="checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt")
def main(ckpt="checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign.ckpt"):
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: no
SYMBOL INDEX (2669 symbols across 182 files)
FILE: evaluation/evaluate.py
function mask_decode (line 9) | def mask_decode(encoded_mask,image_shape=[512,512]):
function calculate_metric (line 29) | def calculate_metric(metrics_calculator,metric, src_image, tgt_image, sr...
FILE: evaluation/matrics_calculator.py
class VitExtractor (line 12) | class VitExtractor:
method __init__ (line 19) | def __init__(self, model_name, device):
method _init_hooks_data (line 32) | def _init_hooks_data(self):
method _register_hooks (line 41) | def _register_hooks(self, **kwargs):
method _clear_hooks (line 52) | def _clear_hooks(self):
method _get_block_hook (line 57) | def _get_block_hook(self):
method _get_attn_hook (line 63) | def _get_attn_hook(self):
method _get_qkv_hook (line 69) | def _get_qkv_hook(self):
method _get_patch_imd_hook (line 76) | def _get_patch_imd_hook(self):
method get_feature_from_input (line 82) | def get_feature_from_input(self, input_img): # List([B, N, D])
method get_qkv_feature_from_input (line 90) | def get_qkv_feature_from_input(self, input_img):
method get_attn_feature_from_input (line 98) | def get_attn_feature_from_input(self, input_img):
method get_patch_size (line 106) | def get_patch_size(self):
method get_width_patch_num (line 109) | def get_width_patch_num(self, input_img_shape):
method get_height_patch_num (line 114) | def get_height_patch_num(self, input_img_shape):
method get_patch_num (line 119) | def get_patch_num(self, input_img_shape):
method get_head_num (line 123) | def get_head_num(self):
method get_embedding_dim (line 128) | def get_embedding_dim(self):
method get_queries_from_qkv (line 133) | def get_queries_from_qkv(self, qkv, input_img_shape):
method get_keys_from_qkv (line 140) | def get_keys_from_qkv(self, qkv, input_img_shape):
method get_values_from_qkv (line 147) | def get_values_from_qkv(self, qkv, input_img_shape):
method get_keys_from_input (line 154) | def get_keys_from_input(self, input_img, layer_num):
method get_keys_self_sim_from_input (line 159) | def get_keys_self_sim_from_input(self, input_img, layer_num):
method attn_cosine_sim (line 166) | def attn_cosine_sim(self,x, eps=1e-08):
class LossG (line 174) | class LossG(torch.nn.Module):
method __init__ (line 175) | def __init__(self, cfg,device):
method update_lambda_config (line 197) | def update_lambda_config(self, step):
method forward (line 209) | def forward(self, outputs, inputs):
method calculate_global_ssim_loss (line 237) | def calculate_global_ssim_loss(self, outputs, inputs):
method calculate_crop_cls_loss (line 248) | def calculate_crop_cls_loss(self, outputs, inputs):
method calculate_global_id_loss (line 259) | def calculate_global_id_loss(self, outputs, inputs):
class MetricsCalculator (line 271) | class MetricsCalculator:
method __init__ (line 272) | def __init__(self, device) -> None:
method calculate_clip_similarity (line 290) | def calculate_clip_similarity(self, img, txt, mask=None):
method calculate_psnr (line 304) | def calculate_psnr(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
method calculate_lpips (line 324) | def calculate_lpips(self, img_pred, img_gt, mask_pred=None, mask_gt=No...
method calculate_mse (line 344) | def calculate_mse(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
method calculate_ssim (line 364) | def calculate_ssim(self, img_pred, img_gt, mask_pred=None, mask_gt=None):
method calculate_structure_distance (line 385) | def calculate_structure_distance(self, img_pred, img_gt, mask_pred=Non...
FILE: models/InstructDiffusion/dataset/editing/edit_zip_dataset.py
class FilteredIP2PDataset (line 35) | class FilteredIP2PDataset(Dataset):
method __init__ (line 36) | def __init__(
method __len__ (line 115) | def __len__(self) -> int:
method __getitem__ (line 118) | def __getitem__(self, i: int) -> dict[str, Any]:
class GIERDataset (line 169) | class GIERDataset(Dataset):
method __init__ (line 170) | def __init__(
method __len__ (line 260) | def __len__(self) -> int:
method __getitem__ (line 263) | def __getitem__(self, i: int) -> dict[str, Any]:
class GQAInpaintDataset (line 296) | class GQAInpaintDataset(Dataset):
method __init__ (line 317) | def __init__(self, **kwargs):
method __len__ (line 330) | def __len__(self):
method __getitem__ (line 333) | def __getitem__(self, i):
class MagicBrushDataset (line 359) | class MagicBrushDataset(Dataset):
method __init__ (line 360) | def __init__(
method __len__ (line 389) | def __len__(self) -> int:
method __getitem__ (line 392) | def __getitem__(self, i: int) -> dict[str, Any]:
class IEIWDataset (line 426) | class IEIWDataset(Dataset):
method __init__ (line 427) | def __init__(
method __len__ (line 456) | def __len__(self) -> int:
method __getitem__ (line 459) | def __getitem__(self, i: int) -> dict[str, Any]:
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_clwd.py
function is_image_file (line 20) | def is_image_file(filename):
class CLWD (line 24) | class CLWD(Dataset):
method __init__ (line 25) | def __init__(self, path, split="train", size=256, interpolation="pil_l...
method __len__ (line 66) | def __len__(self):
method __getitem__ (line 69) | def __getitem__(self, index):
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_gopro.py
function is_image_file (line 20) | def is_image_file(filename):
class GoPro (line 24) | class GoPro(Dataset):
method __init__ (line 25) | def __init__(self, path, split="train", size=256, interpolation="pil_l...
method __len__ (line 66) | def __len__(self):
method __getitem__ (line 69) | def __getitem__(self, index):
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_reds.py
function is_image_file (line 20) | def is_image_file(filename):
class REDS (line 24) | class REDS(Dataset):
method __init__ (line 25) | def __init__(self, path, split="train", size=256, interpolation="pil_l...
method __len__ (line 71) | def __len__(self):
method __getitem__ (line 74) | def __getitem__(self, index):
FILE: models/InstructDiffusion/dataset/low_level/lowlevel_sidd.py
function is_image_file (line 20) | def is_image_file(filename):
class SIDD (line 24) | class SIDD(Dataset):
method __init__ (line 25) | def __init__(self, path, split="train", size=256, interpolation="pil_l...
method __len__ (line 65) | def __len__(self):
method __getitem__ (line 68) | def __getitem__(self, index):
FILE: models/InstructDiffusion/dataset/pose/pose.py
function readTXT (line 45) | def readTXT(txt_path):
class PoseDataset (line 52) | class PoseDataset(Dataset):
method __init__ (line 53) | def __init__(self, root, image_set, is_train, max_prompt_num=5, min_pr...
method _get_db (line 93) | def _get_db(self):
method evaluate (line 96) | def evaluate(self, preds, output_dir, *args, **kwargs):
method half_body_transform (line 99) | def half_body_transform(self, joints, joints_vis):
method __len__ (line 144) | def __len__(self,):
method __getitem__ (line 147) | def __getitem__(self, idx):
method generate_target (line 221) | def generate_target(self, input, joints, joints_vis):
class COCODataset (line 281) | class COCODataset(PoseDataset):
method __init__ (line 282) | def __init__(self, root, image_set, is_train, max_prompt_num=5, min_pr...
method _get_ann_file_keypoint (line 347) | def _get_ann_file_keypoint(self):
method _load_image_set_index (line 374) | def _load_image_set_index(self):
method _get_db (line 379) | def _get_db(self):
method _load_coco_keypoint_annotations (line 383) | def _load_coco_keypoint_annotations(self):
method _load_coco_keypoint_annotation_kernal (line 390) | def _load_coco_keypoint_annotation_kernal(self, index):
method _box2cs (line 459) | def _box2cs(self, box):
method _xywh2cs (line 463) | def _xywh2cs(self, x, y, w, h):
method image_path_from_index (line 480) | def image_path_from_index(self, index, im_ann):
function flip_back (line 511) | def flip_back(output_flipped, matched_parts):
function fliplr_joints (line 528) | def fliplr_joints(joints, joints_vis, width, matched_parts):
function get_affine_transform (line 545) | def get_affine_transform(
function affine_transform (line 580) | def affine_transform(pt, t):
function get_3rd_point (line 586) | def get_3rd_point(a, b):
function get_dir (line 591) | def get_dir(src_point, rot_rad):
class CrowdPoseDataset (line 601) | class CrowdPoseDataset(COCODataset):
method __init__ (line 602) | def __init__(self, root, image_set, is_train, max_prompt_num=5, min_pr...
class AICDataset (line 637) | class AICDataset(COCODataset):
method __init__ (line 638) | def __init__(self, root, image_set, is_train, max_prompt_num=5, min_pr...
class MPIIDataset (line 672) | class MPIIDataset(PoseDataset):
method __init__ (line 673) | def __init__(self, root, image_set, is_train, max_prompt_num=5, min_pr...
method _get_db (line 709) | def _get_db(self):
FILE: models/InstructDiffusion/dataset/seg/coco_stuff.py
class COCOStuffDataset (line 27) | class COCOStuffDataset(Dataset):
method __init__ (line 28) | def __init__(
method __len__ (line 92) | def __len__(self) -> int:
method _augmentation_new (line 96) | def _augmentation_new(self, image, label):
method __getitem__ (line 118) | def __getitem__(self, i):
FILE: models/InstructDiffusion/dataset/seg/grefcoco.py
class G_REFER (line 34) | class G_REFER:
method __init__ (line 36) | def __init__(self, data_root, dataset='grefcoco', splitBy='unc'):
method _toList (line 74) | def _toList(x):
method match_any (line 78) | def match_any(a, b):
method createIndex (line 83) | def createIndex(self):
method getRefIds (line 163) | def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
method getAnnIds (line 185) | def getAnnIds(self, image_ids=[], ref_ids=[]):
method getImgIds (line 205) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 214) | def getCatIds(self):
method loadRefs (line 217) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 220) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 225) | def loadImgs(self, image_ids=[]):
method loadCats (line 228) | def loadCats(self, cat_ids=[]):
method getRefBox (line 231) | def getRefBox(self, ref_id):
method showRef (line 235) | def showRef(self, ref, seg_box='seg'):
method getMask (line 278) | def getMask(self, ann):
method getMaskByRef (line 296) | def getMaskByRef(self, ref=None, ref_id=None, merge=False):
method showMask (line 325) | def showMask(self, ref):
FILE: models/InstructDiffusion/dataset/seg/grefcoco_segmentation.py
class GrefCOCODataset (line 27) | class GrefCOCODataset(Dataset):
method __init__ (line 28) | def __init__(
method __len__ (line 73) | def __len__(self) -> int:
method _augmentation_new (line 76) | def _augmentation_new(self, image, label):
method __getitem__ (line 98) | def __getitem__(self, i: int) -> dict[str, Any]:
FILE: models/InstructDiffusion/dataset/seg/refcoco.py
class REFER (line 44) | class REFER:
method __init__ (line 46) | def __init__(self, data_root, dataset='refcoco', splitBy='unc'):
method createIndex (line 79) | def createIndex(self):
method getRefIds (line 142) | def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
method getAnnIds (line 173) | def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
method getImgIds (line 193) | def getImgIds(self, ref_ids=[]):
method getCatIds (line 202) | def getCatIds(self):
method loadRefs (line 205) | def loadRefs(self, ref_ids=[]):
method loadAnns (line 211) | def loadAnns(self, ann_ids=[]):
method loadImgs (line 217) | def loadImgs(self, image_ids=[]):
method loadCats (line 223) | def loadCats(self, cat_ids=[]):
method getRefBox (line 229) | def getRefBox(self, ref_id):
method showRef (line 234) | def showRef(self, ref, seg_box='seg'):
method getMask (line 277) | def getMask(self, ref):
method showMask (line 327) | def showMask(self, ref):
FILE: models/InstructDiffusion/dataset/seg/refcoco_segmentation.py
class RefCOCODataset (line 27) | class RefCOCODataset(Dataset):
method __init__ (line 28) | def __init__(
method __len__ (line 73) | def __len__(self) -> int:
method _augmentation_new (line 76) | def _augmentation_new(self, image, label):
method __getitem__ (line 98) | def __getitem__(self, i: int) -> dict[str, Any]:
FILE: models/InstructDiffusion/dataset/utils/zip_manager.py
function norm (line 21) | def norm(path):
function in_blob (line 25) | def in_blob(file):
function map_name (line 31) | def map_name(file):
function preload (line 39) | def preload(db,sync=False):
function get_keys_from_lmdb (line 46) | def get_keys_from_lmdb(db):
function decode_img (line 50) | def decode_img(byteflow):
function decode_text (line 59) | def decode_text(byteflow):
class ZipManager (line 68) | class ZipManager:
method __init__ (line 69) | def __init__(self, zip_path,data_type,prefix=None) -> None:
method deinitialze (line 75) | def deinitialze(self):
method initialize (line 80) | def initialize(self,close=True):
method keys (line 89) | def keys(self):
method get (line 94) | def get(self, name):
class MultipleZipManager (line 101) | class MultipleZipManager:
method __init__ (line 102) | def __init__(self, files: list, data_type, sync=True):
method initialize (line 115) | def initialize(self):
method keys (line 136) | def keys(self):
method get (line 141) | def get(self, name):
FILE: models/InstructDiffusion/edit_app.py
function load_model_from_config (line 36) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
function read_content (line 49) | def read_content(file_path: str) -> str:
function get_header (line 58) | def get_header():
class CFGDenoiser (line 85) | class CFGDenoiser(nn.Module):
method __init__ (line 86) | def __init__(self, model):
method forward (line 90) | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_sc...
function predict (line 101) | def predict(
function main (line 219) | def main(ckpt="checkpoints/v1-5-pruned-emaonly-adaption-task-humanalign....
FILE: models/InstructDiffusion/edit_cli.py
class CFGDenoiser (line 31) | class CFGDenoiser(nn.Module):
method __init__ (line 32) | def __init__(self, model):
method forward (line 36) | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_sc...
function load_model_from_config (line 50) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
function main (line 63) | def main():
FILE: models/InstructDiffusion/main.py
function wandb_log (line 38) | def wandb_log(*args, **kwargs):
function get_parser (line 43) | def get_parser(**parser_kwargs):
class WrappedDataset (line 156) | class WrappedDataset(Dataset):
method __init__ (line 159) | def __init__(self, dataset):
method __len__ (line 162) | def __len__(self):
method __getitem__ (line 165) | def __getitem__(self, idx):
class DataModuleFromConfig (line 169) | class DataModuleFromConfig():
method __init__ (line 170) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 199) | def prepare_data(self):
method setup (line 203) | def setup(self, stage=None):
method _train_concat_dataloader (line 211) | def _train_concat_dataloader(self):
method _train_dataloader (line 230) | def _train_dataloader(self):
method _val_dataloader (line 243) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 254) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 267) | def _predict_dataloader(self, shuffle=False):
function train_one_epoch (line 276) | def train_one_epoch(config, model, model_ema, data_loader, val_data_load...
FILE: models/InstructDiffusion/stable_diffusion/ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/autoencoder.py
class VQModel (line 20) | class VQModel(nn.Module):
method __init__ (line 21) | def __init__(self,
method ema_scope (line 70) | def ema_scope(self, context=None):
method init_from_ckpt (line 84) | def init_from_ckpt(self, path, ignore_keys=list()):
method on_train_batch_end (line 98) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 102) | def encode(self, x):
method encode_to_prequant (line 108) | def encode_to_prequant(self, x):
method decode (line 113) | def decode(self, quant):
method decode_code (line 118) | def decode_code(self, code_b):
method forward (line 123) | def forward(self, input, return_pred_indices=False):
method get_input (line 130) | def get_input(self, batch, k):
method training_step (line 148) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 170) | def validation_step(self, batch, batch_idx):
method _validation_step (line 176) | def _validation_step(self, batch, batch_idx, suffix=""):
method configure_optimizers (line 203) | def configure_optimizers(self):
method get_last_layer (line 236) | def get_last_layer(self):
method log_images (line 239) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
method to_rgb (line 261) | def to_rgb(self, x):
class VQModelInterface (line 270) | class VQModelInterface(VQModel):
method __init__ (line 271) | def __init__(self, embed_dim, *args, **kwargs):
method encode (line 275) | def encode(self, x):
method decode (line 280) | def decode(self, h, force_not_quantize=False):
class AutoencoderKL (line 291) | class AutoencoderKL(nn.Module):
method __init__ (line 292) | def __init__(self,
method init_from_ckpt (line 319) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 330) | def encode(self, x):
method decode (line 336) | def decode(self, z):
method forward (line 341) | def forward(self, input, sample_posterior=True):
method get_input (line 350) | def get_input(self, batch, k):
method training_step (line 357) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 378) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 392) | def configure_optimizers(self):
method get_last_layer (line 403) | def get_last_layer(self):
method log_images (line 407) | def log_images(self, batch, only_inputs=False, **kwargs):
method to_rgb (line 423) | def to_rgb(self, x):
class IdentityFirstStage (line 432) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 433) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 437) | def encode(self, x, *args, **kwargs):
method decode (line 440) | def decode(self, x, *args, **kwargs):
method quantize (line 443) | def quantize(self, x, *args, **kwargs):
method forward (line 448) | def forward(self, x, *args, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/classifier.py
function disabled_train (line 22) | def disabled_train(self, mode=True):
class NoisyLatentImageClassifier (line 28) | class NoisyLatentImageClassifier(pl.LightningModule):
method __init__ (line 30) | def __init__(self,
method init_from_ckpt (line 70) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method load_diffusion (line 88) | def load_diffusion(self):
method load_classifier (line 95) | def load_classifier(self, ckpt_path, pool):
method get_x_noisy (line 110) | def get_x_noisy(self, x, t, noise=None):
method forward (line 120) | def forward(self, x_noisy, t, *args, **kwargs):
method get_input (line 124) | def get_input(self, batch, k):
method get_conditioning (line 133) | def get_conditioning(self, batch, k=None):
method compute_top_k (line 150) | def compute_top_k(self, logits, labels, k, reduction="mean"):
method on_train_epoch_start (line 157) | def on_train_epoch_start(self):
method write_logs (line 162) | def write_logs(self, loss, logits, targets):
method shared_step (line 179) | def shared_step(self, batch, t=None):
method training_step (line 198) | def training_step(self, batch, batch_idx):
method reset_noise_accs (line 202) | def reset_noise_accs(self):
method on_validation_start (line 206) | def on_validation_start(self):
method validation_step (line 210) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 220) | def configure_optimizers(self):
method log_images (line 238) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddim.py
class DDIMSampler (line 12) | class DDIMSampler(object):
method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 19) | def register_buffer(self, name, attr):
method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 57) | def sample(self,
method ddim_sampling (line 114) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 166) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method stochastic_encode (line 207) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 223) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddpm.py
function disabled_train (line 34) | def disabled_train(self, mode=True):
function uniform_on_device (line 40) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 44) | class DDPM(pl.LightningModule):
method __init__ (line 46) | def __init__(self,
method register_schedule (line 117) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 172) | def ema_scope(self, context=None):
method init_from_ckpt (line 186) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 204) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 216) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 222) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 231) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 244) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 253) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 268) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 274) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 279) | def get_loss(self, pred, target, mean=True):
method p_losses (line 294) | def p_losses(self, x_start, t, noise=None):
method forward (line 323) | def forward(self, x, *args, **kwargs):
method get_input (line 329) | def get_input(self, batch, k):
method shared_step (line 337) | def shared_step(self, batch):
method training_step (line 342) | def training_step(self, batch, batch_idx):
method validation_step (line 358) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 366) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 370) | def _get_rows_from_list(self, samples):
method log_images (line 378) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 415) | def configure_optimizers(self):
class LatentDiffusion (line 424) | class LatentDiffusion(DDPM):
method __init__ (line 426) | def __init__(self,
method make_cond_schedule (line 471) | def make_cond_schedule(self, ):
method on_train_batch_start (line 478) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 493) | def register_schedule(self,
method instantiate_first_stage (line 502) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 509) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 530) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 542) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 551) | def get_learned_conditioning(self, c):
method meshgrid (line 564) | def meshgrid(self, h, w):
method delta_border (line 571) | def delta_border(self, h, w):
method get_weighting (line 585) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 601) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 654) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 706) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method differentiable_decode_first_stage (line 766) | def differentiable_decode_first_stage(self, z, predict_cids=False, for...
method encode_first_stage (line 826) | def encode_first_stage(self, x):
method shared_step (line 865) | def shared_step(self, batch, **kwargs):
method forward (line 870) | def forward(self, x, c, *args, **kwargs):
method _rescale_annotations (line 881) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 891) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 994) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 998) | def _prior_bpd(self, x_start):
method p_losses (line 1012) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1047) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1079) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1110) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1166) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1217) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1235) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 1251) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1361) | def configure_optimizers(self):
method to_rgb (line 1386) | def to_rgb(self, x):
class DiffusionWrapper (line 1395) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1396) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1402) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class Layout2ImgDiffusion (line 1424) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1426) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1430) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddpm_edit.py
function disabled_train (line 37) | def disabled_train(self, mode=True):
function uniform_on_device (line 43) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 47) | class DDPM(nn.Module):
method __init__ (line 49) | def __init__(self,
method register_schedule (line 117) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method init_from_ckpt (line 171) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 214) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 226) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 232) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 241) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 254) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 263) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 278) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 284) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 289) | def get_loss(self, pred, target, mean=True):
method p_losses (line 305) | def p_losses(self, x_start, t, noise=None):
method forward (line 334) | def forward(self, x, *args, **kwargs):
method get_input (line 340) | def get_input(self, batch, k):
class NNParams (line 344) | class NNParams(nn.Module):
method __init__ (line 345) | def __init__(self, dim):
method forward (line 350) | def forward(self):
class LatentDiffusion (line 354) | class LatentDiffusion(DDPM):
method __init__ (line 356) | def __init__(self,
method make_cond_schedule (line 405) | def make_cond_schedule(self, ):
method on_train_batch_start (line 412) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 426) | def register_schedule(self,
method instantiate_first_stage (line 435) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 442) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 463) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 475) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 484) | def get_learned_conditioning(self, c):
method meshgrid (line 497) | def meshgrid(self, h, w):
method delta_border (line 504) | def delta_border(self, h, w):
method get_weighting (line 518) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 534) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 587) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 618) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method differentiable_decode_first_stage (line 678) | def differentiable_decode_first_stage(self, z, predict_cids=False, for...
method encode_first_stage (line 738) | def encode_first_stage(self, x):
method forward (line 777) | def forward(self, batch, batch_idx, num_steps, *args, **kwargs):
method _rescale_annotations (line 791) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 801) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 904) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 908) | def _prior_bpd(self, x_start):
method p_losses (line 922) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 974) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1006) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1037) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1093) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1144) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1162) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
class DiffusionWrapper (line 1177) | class DiffusionWrapper(nn.Module):
method __init__ (line 1178) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1184) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class Layout2ImgDiffusion (line 1206) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1208) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1212) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py
class NoiseScheduleVP (line 6) | class NoiseScheduleVP:
method __init__ (line 7) | def __init__(
method marginal_log_mean_coeff (line 125) | def marginal_log_mean_coeff(self, t):
method marginal_alpha (line 138) | def marginal_alpha(self, t):
method marginal_std (line 144) | def marginal_std(self, t):
method marginal_lambda (line 150) | def marginal_lambda(self, t):
method inverse_lambda (line 158) | def inverse_lambda(self, lamb):
function model_wrapper (line 177) | def model_wrapper(
class DPM_Solver (line 351) | class DPM_Solver:
method __init__ (line 352) | def __init__(self, model_fn, noise_schedule, predict_x0=False, thresho...
method noise_prediction_fn (line 380) | def noise_prediction_fn(self, x, t):
method data_prediction_fn (line 386) | def data_prediction_fn(self, x, t):
method model_fn (line 401) | def model_fn(self, x, t):
method get_time_steps (line 410) | def get_time_steps(self, skip_type, t_T, t_0, N, device):
method get_orders_and_timesteps_for_singlestep_solver (line 439) | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order,...
method denoise_to_zero_fn (line 498) | def denoise_to_zero_fn(self, x, s):
method dpm_solver_first_update (line 504) | def dpm_solver_first_update(self, x, s, t, model_s=None, return_interm...
method singlestep_dpm_solver_second_update (line 551) | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s...
method singlestep_dpm_solver_third_update (line 633) | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./...
method multistep_dpm_solver_second_update (line 755) | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_pre...
method multistep_dpm_solver_third_update (line 812) | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev...
method singlestep_dpm_solver_update (line 859) | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermed...
method multistep_dpm_solver_update (line 885) | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list,...
method dpm_solver_adaptive (line 909) | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0....
method sample (line 965) | def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_...
function interpolate_fn (line 1132) | def interpolate_fn(x, xp, yp):
function expand_dims (line 1174) | def expand_dims(v, dims):
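For orientation on dpm_solver.py: the first-order update behind dpm_solver_first_update follows the DPM-Solver paper. A minimal sketch in the noise-prediction setting, assuming marginal coefficients alpha_t, sigma_t and half-log-SNR lambda_t = log(alpha_t / sigma_t); the repository's DPM_Solver class additionally supports data prediction, thresholding, and higher-order single/multistep updates.

```python
import torch

def dpm_solver_1_step(x_s, eps_s, alpha_s, alpha_t, sigma_t, lambda_s, lambda_t):
    """One first-order DPM-Solver step from time s to time t (t < s), sketch only.

    x_t = (alpha_t / alpha_s) * x_s - sigma_t * (exp(h) - 1) * eps_theta(x_s, s),
    with h = lambda_t - lambda_s.
    """
    h = lambda_t - lambda_s
    return (alpha_t / alpha_s) * x_s - sigma_t * torch.expm1(h) * eps_s
```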
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/dpm_solver/sampler.py
class DPMSolverSampler (line 8) | class DPMSolverSampler(object):
method __init__ (line 9) | def __init__(self, model, **kwargs):
method register_buffer (line 15) | def register_buffer(self, name, attr):
method sample (line 22) | def sample(self,
FILE: models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/plms.py
class PLMSSampler (line 11) | class PLMSSampler(object):
method __init__ (line 12) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 58) | def sample(self,
method plms_sampling (line 115) | def plms_sampling(self, cond, shape,
method p_sample_plms (line 173) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/attention.py
function exists (line 14) | def exists(val):
function uniq (line 18) | def uniq(arr):
function default (line 22) | def default(val, d):
function max_neg_value (line 28) | def max_neg_value(t):
function init_ (line 32) | def init_(tensor):
class GEGLU (line 40) | class GEGLU(nn.Module):
method __init__ (line 41) | def __init__(self, dim_in, dim_out):
method forward (line 45) | def forward(self, x):
class FeedForward (line 50) | class FeedForward(nn.Module):
method __init__ (line 51) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 66) | def forward(self, x):
function zero_module (line 70) | def zero_module(module):
function Normalize (line 79) | def Normalize(in_channels, default_eps):
class LinearAttention (line 86) | class LinearAttention(nn.Module):
method __init__ (line 87) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 94) | def forward(self, x):
class SpatialSelfAttention (line 106) | class SpatialSelfAttention(nn.Module):
method __init__ (line 107) | def __init__(self, in_channels):
method forward (line 133) | def forward(self, x):
class CrossAttention (line 160) | class CrossAttention(nn.Module):
method __init__ (line 161) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 180) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 242) | class BasicTransformerBlock(nn.Module):
method __init__ (line 246) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 263) | def forward(self, x, context=None):
method _forward (line 266) | def _forward(self, x, context=None):
class SpatialTransformer (line 330) | class SpatialTransformer(nn.Module):
method __init__ (line 339) | def __init__(self, in_channels, n_heads, d_head, default_eps, force_ty...
method forward (line 374) | def forward(self, x, context=None):
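attention.py's CrossAttention is what injects text conditioning into the UNet. A hedged single-head sketch of the mechanism follows; the real module is multi-head, uses einops reshapes, and supports masking.

```python
import torch
import torch.nn as nn

class MinimalCrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention (illustrative only)."""

    def __init__(self, query_dim, context_dim=None, dim_head=64):
        super().__init__()
        context_dim = context_dim if context_dim is not None else query_dim
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, dim_head, bias=False)
        self.to_k = nn.Linear(context_dim, dim_head, bias=False)
        self.to_v = nn.Linear(context_dim, dim_head, bias=False)
        self.to_out = nn.Linear(dim_head, query_dim)

    def forward(self, x, context=None):
        context = x if context is None else context   # self-attention fallback
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        return self.to_out(attn @ v)
```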
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 14) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 35) | def nonlinearity(x):
function Normalize (line 40) | def Normalize(in_channels, default_eps, num_groups=32):
class Upsample (line 47) | class Upsample(nn.Module):
method __init__ (line 48) | def __init__(self, in_channels, with_conv):
method forward (line 58) | def forward(self, x):
class Downsample (line 65) | class Downsample(nn.Module):
method __init__ (line 66) | def __init__(self, in_channels, with_conv):
method forward (line 77) | def forward(self, x):
class ResnetBlock (line 87) | class ResnetBlock(nn.Module):
method __init__ (line 88) | def __init__(self, *, in_channels, default_eps, force_type_convert, ou...
method forward (line 127) | def forward(self, x, temb):
class LinAttnBlock (line 158) | class LinAttnBlock(LinearAttention):
method __init__ (line 160) | def __init__(self, in_channels):
class AttnBlock (line 164) | class AttnBlock(nn.Module):
method __init__ (line 165) | def __init__(self, in_channels, default_eps, force_type_convert):
method forward (line 192) | def forward(self, x):
function make_attn (line 223) | def make_attn(in_channels, default_eps, force_type_convert, attn_type="v...
class Model (line 234) | class Model(nn.Module):
method __init__ (line 235) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 334) | def forward(self, x, t=None, context=None):
method get_last_layer (line 382) | def get_last_layer(self):
class Encoder (line 386) | class Encoder(nn.Module):
method __init__ (line 387) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 459) | def forward(self, x):
class Decoder (line 491) | class Decoder(nn.Module):
method __init__ (line 492) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 571) | def forward(self, z):
class SimpleDecoder (line 611) | class SimpleDecoder(nn.Module):
method __init__ (line 612) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 634) | def forward(self, x):
class UpsampleDecoder (line 647) | class UpsampleDecoder(nn.Module):
method __init__ (line 648) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 681) | def forward(self, x):
class LatentRescaler (line 695) | class LatentRescaler(nn.Module):
method __init__ (line 696) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 720) | def forward(self, x):
class MergedRescaleEncoder (line 732) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 733) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 745) | def forward(self, x):
class MergedRescaleDecoder (line 751) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 752) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 762) | def forward(self, x):
class Upsampler (line 768) | class Upsampler(nn.Module):
method __init__ (line 769) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 781) | def forward(self, x):
class Resize (line 787) | class Resize(nn.Module):
method __init__ (line 788) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 803) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 810) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 812) | def __init__(self, ch_mult:list, in_channels,
method instantiate_pretrained (line 847) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 856) | def encode_with_pretrained(self,x):
method forward (line 862) | def forward(self,x):
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 30) | def convert_module_to_f16(l):
function convert_module_to_f32 (line 40) | def convert_module_to_f32(l):
function convert_some_linear_to_f16 (line 50) | def convert_some_linear_to_f16(l):
function convert_some_linear_to_f32 (line 60) | def convert_some_linear_to_f32(l):
class PositionEmbedding (line 70) | class PositionEmbedding(nn.Module):
method __init__ (line 71) | def __init__(self, embed_dim, spacial_dim):
method forward (line 74) | def forward(self):
class AttentionPool2d (line 79) | class AttentionPool2d(nn.Module):
method __init__ (line 84) | def __init__(
method forward (line 98) | def forward(self, x):
class TimestepBlock (line 109) | class TimestepBlock(nn.Module):
method forward (line 115) | def forward(self, x, emb):
class TimestepEmbedSequential (line 121) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 127) | def forward(self, x, emb, context=None):
class Upsample (line 144) | class Upsample(nn.Module):
method __init__ (line 153) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 162) | def forward(self, x):
class TransposedUpsample (line 174) | class TransposedUpsample(nn.Module):
method __init__ (line 176) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 183) | def forward(self,x):
class Downsample (line 187) | class Downsample(nn.Module):
method __init__ (line 196) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 211) | def forward(self, x):
class ResBlock (line 216) | class ResBlock(TimestepBlock):
method __init__ (line 232) | def __init__(
method forward (line 296) | def forward(self, x, emb):
method _forward (line 311) | def _forward(self, x, emb):
class AttentionBlock (line 335) | class AttentionBlock(nn.Module):
method __init__ (line 342) | def __init__(
method forward (line 371) | def forward(self, x):
method _forward (line 375) | def _forward(self, x):
function count_flops_attn (line 384) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 404) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 409) | def __init__(self, n_heads):
method forward (line 413) | def forward(self, qkv):
method count_flops (line 432) | def count_flops(model, _x, y):
class QKVAttention (line 436) | class QKVAttention(nn.Module):
method __init__ (line 441) | def __init__(self, n_heads):
method forward (line 445) | def forward(self, qkv):
method count_flops (line 466) | def count_flops(model, _x, y):
class UNetModel (line 470) | class UNetModel(nn.Module):
method __init__ (line 500) | def __init__(
method convert_to_fp16 (line 761) | def convert_to_fp16(self):
method convert_to_fp32 (line 773) | def convert_to_fp32(self):
method forward (line 785) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
class EncoderUNetModel (line 820) | class EncoderUNetModel(nn.Module):
method __init__ (line 826) | def __init__(
method convert_to_fp16 (line 1002) | def convert_to_fp16(self):
method convert_to_fp32 (line 1013) | def convert_to_fp32(self):
method forward (line 1024) | def forward(self, x, timesteps):
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 22) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 47) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 64) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 78) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 97) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 103) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 136) | class CheckpointFunction(torch.autograd.Function):
method forward (line 139) | def forward(ctx, run_function, length, *args):
method backward (line 150) | def backward(ctx, *output_grads):
function timestep_embedding (line 170) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 193) | def zero_module(module):
function scale_module (line 202) | def scale_module(module, scale):
function mean_flat (line 211) | def mean_flat(tensor):
function normalization (line 218) | def normalization(channels):
class SiLU (line 228) | class SiLU(nn.Module):
method forward (line 229) | def forward(self, x):
class GroupNorm32 (line 233) | class GroupNorm32(nn.GroupNorm):
method forward (line 234) | def forward(self, x):
function conv_nd (line 237) | def conv_nd(dims, *args, **kwargs):
function linear (line 250) | def linear(*args, **kwargs):
function avg_pool_nd (line 257) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 270) | class HybridConditioner(nn.Module):
method __init__ (line 272) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 277) | def forward(self, c_concat, c_crossattn):
function noise_like (line 283) | def noise_like(shape, device, repeat=False):
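diffusionmodules/util.py's timestep_embedding is the usual sinusoidal embedding fed to the UNet's time MLP; a sketch under that assumption:

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    """Standard sinusoidal embedding; a sketch of what timestep_embedding computes."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:  # zero-pad odd embedding sizes
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
```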
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
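distributions.py parameterizes Gaussians by mean and log-variance; normal_kl is then the closed-form KL between two diagonal Gaussians. A sketch under that assumption:

```python
import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    """KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ), elementwise (sketch)."""
    return 0.5 * (
        -1.0
        + logvar2 - logvar1
        + torch.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    )
```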
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/ema.py
class LitEma (line 11) | class LitEma(nn.Module):
method __init__ (line 12) | def __init__(self, model, decay=0.9999, decay_resume=0.9999, use_num_u...
method forward (line 32) | def forward(self, model):
method copy_to (line 55) | def copy_to(self, model, test=False):
method store (line 67) | def store(self, parameters):
method restore (line 77) | def restore(self, parameters):
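ema.py's LitEma keeps an exponential moving average of the model weights; ignoring the num_updates warm-up in the real class, the core update per parameter is the following sketch:

```python
import torch

@torch.no_grad()
def ema_update(shadow_params, model_params, decay=0.9999):
    """shadow <- decay * shadow + (1 - decay) * param, in place (illustrative)."""
    for shadow, param in zip(shadow_params, model_params):
        shadow.mul_(decay).add_(param.detach(), alpha=1.0 - decay)
```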
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/encoders/modules.py
class AbstractEncoder (line 12) | class AbstractEncoder(nn.Module):
method __init__ (line 13) | def __init__(self):
method encode (line 16) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 21) | class ClassEmbedder(nn.Module):
method __init__ (line 22) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 27) | def forward(self, batch, key=None):
class TransformerEmbedder (line 36) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 38) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 44) | def forward(self, tokens):
method encode (line 49) | def encode(self, x):
class BERTTokenizer (line 53) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 55) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 63) | def forward(self, text):
method encode (line 70) | def encode(self, text):
method decode (line 76) | def decode(self, text):
class BERTEmbedder (line 80) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 82) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 93) | def forward(self, text):
method encode (line 101) | def encode(self, text):
class SpatialRescaler (line 106) | class SpatialRescaler(nn.Module):
method __init__ (line 107) | def __init__(self,
method forward (line 125) | def forward(self,x):
method encode (line 134) | def encode(self, x):
class FrozenCLIPEmbedder (line 137) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 139) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 147) | def freeze(self):
method forward (line 152) | def forward(self, text):
method encode (line 161) | def encode(self, text):
class FrozenCLIPTextEmbedder (line 165) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 169) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 177) | def freeze(self):
method forward (line 182) | def forward(self, text):
method encode (line 189) | def encode(self, text):
class FrozenClipImageEmbedder (line 197) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 201) | def __init__(
method preprocess (line 216) | def preprocess(self, x):
method forward (line 226) | def forward(self, x):
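FrozenCLIPEmbedder wraps the Hugging Face CLIP text encoder and returns per-token hidden states as the conditioning sequence. A hedged sketch of the typical pattern; the model version and 77-token length mirror the listed defaults but the exact flow here is an assumption.

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

def encode_prompt(text, version="openai/clip-vit-large-patch14", device="cpu", max_length=77):
    """Tokenize and encode text with a frozen CLIP text model (illustrative)."""
    tokenizer = CLIPTokenizer.from_pretrained(version)
    encoder = CLIPTextModel.from_pretrained(version).to(device).eval()
    for p in encoder.parameters():
        p.requires_grad = False
    batch = tokenizer(text, truncation=True, max_length=max_length,
                      padding="max_length", return_tensors="pt")
    with torch.no_grad():
        out = encoder(input_ids=batch["input_ids"].to(device))
    return out.last_hidden_state   # (batch, max_length, hidden_dim)
```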
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
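The bsrgan.py helpers degrade float images in [0, 1]. As an illustration of the simplest one, additive Gaussian noise with a randomly drawn level; the real add_Gaussian_noise also mixes grayscale and colored-covariance noise, so this is a sketch only.

```python
import numpy as np

def add_gaussian_noise(img, noise_level1=2, noise_level2=25, rng=np.random):
    """Add white Gaussian noise with a level drawn from [noise_level1, noise_level2]/255."""
    sigma = rng.randint(noise_level1, noise_level2) / 255.0
    noisy = img + rng.normal(0.0, sigma, img.shape).astype(np.float32)
    return np.clip(noisy, 0.0, 1.0)
```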
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 343) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 373) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 390) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 408) | def add_Poisson_noise(img):
function add_JPEG_noise (line 422) | def add_JPEG_noise(img):
function random_crop (line 431) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 442) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 534) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
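utils_image.py bundles the usual uint8 image metrics; calculate_psnr follows the standard 8-bit definition, sketched here without the border cropping the real function supports:

```python
import numpy as np

def psnr_uint8(img1, img2):
    """PSNR between two uint8 images of identical shape (sketch)."""
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 20.0 * np.log10(255.0 / np.sqrt(mse))
```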
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/losses/contperceptual.py
class LPIPSWithDiscriminator (line 7) | class LPIPSWithDiscriminator(nn.Module):
method __init__ (line 8) | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixello...
method calculate_adaptive_weight (line 32) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 45) | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
FILE: models/InstructDiffusion/stable_diffusion/ldm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 25) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 26) | def __init__(self, dim, max_seq_len):
method init_ (line 31) | def init_(self):
method forward (line 34) | def forward(self, x):
class FixedPositionalEmbedding (line 39) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 40) | def __init__(self, dim):
method forward (line 45) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 54) | def exists(val):
function default (line 58) | def default(val, d):
function always (line 64) | def always(val):
function not_equals (line 70) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 82) | def max_neg_value(tensor):
function pick_and_pop (line 88) | def pick_and_pop(keys, d):
function group_dict_by_key (line 93) | def group_dict_by_key(cond, d):
function string_begins_with (line 102) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 106) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 110) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 117) | class Scale(nn.Module):
method __init__ (line 118) | def __init__(self, value, fn):
method forward (line 123) | def forward(self, x, **kwargs):
class Rezero (line 128) | class Rezero(nn.Module):
method __init__ (line 129) | def __init__(self, fn):
method forward (line 134) | def forward(self, x, **kwargs):
class ScaleNorm (line 139) | class ScaleNorm(nn.Module):
method __init__ (line 140) | def __init__(self, dim, eps=1e-5):
method forward (line 146) | def forward(self, x):
class RMSNorm (line 151) | class RMSNorm(nn.Module):
method __init__ (line 152) | def __init__(self, dim, eps=1e-8):
method forward (line 158) | def forward(self, x):
class Residual (line 163) | class Residual(nn.Module):
method forward (line 164) | def forward(self, x, residual):
class GRUGating (line 168) | class GRUGating(nn.Module):
method __init__ (line 169) | def __init__(self, dim):
method forward (line 173) | def forward(self, x, residual):
class GEGLU (line 184) | class GEGLU(nn.Module):
method __init__ (line 185) | def __init__(self, dim_in, dim_out):
method forward (line 189) | def forward(self, x):
class FeedForward (line 194) | class FeedForward(nn.Module):
method __init__ (line 195) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 210) | def forward(self, x):
class Attention (line 215) | class Attention(nn.Module):
method __init__ (line 216) | def __init__(
method forward (line 268) | def forward(
class AttentionLayers (line 370) | class AttentionLayers(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 481) | def forward(
class Encoder (line 541) | class Encoder(AttentionLayers):
method __init__ (line 542) | def __init__(self, **kwargs):
class TransformerWrapper (line 548) | class TransformerWrapper(nn.Module):
method __init__ (line 549) | def __init__(
method init_ (line 595) | def init_(self):
method forward (line 598) | def forward(
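x_transformer.py ships small normalization variants alongside the attention stack; RMSNorm, for instance, rescales by the root-mean-square of the last dimension with a learned gain. A sketch under that assumption:

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Root-mean-square normalization with a learned gain (illustrative)."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = dim ** -0.5
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x.norm(dim=-1, keepdim=True) * self.scale   # RMS over the last dim
        return x / norm.clamp(min=self.eps) * self.g
```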
FILE: models/InstructDiffusion/stable_diffusion/ldm/util.py
function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 41) | def ismap(x):
function isimage (line 47) | def isimage(x):
function exists (line 53) | def exists(x):
function default (line 57) | def default(val, d):
function mean_flat (line 63) | def mean_flat(tensor):
function count_params (line 71) | def count_params(model, verbose=False):
function instantiate_from_config (line 78) | def instantiate_from_config(config):
function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
FILE: models/InstructDiffusion/stable_diffusion/main.py
function get_parser (line 24) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 126) | def nondefault_trainer_args(opt):
class WrappedDataset (line 133) | class WrappedDataset(Dataset):
method __init__ (line 136) | def __init__(self, dataset):
method __len__ (line 139) | def __len__(self):
method __getitem__ (line 142) | def __getitem__(self, idx):
function worker_init_fn (line 146) | def worker_init_fn(_):
class DataModuleFromConfig (line 162) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 163) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 185) | def prepare_data(self):
method setup (line 189) | def setup(self, stage=None):
method _train_dataloader (line 197) | def _train_dataloader(self):
method _val_dataloader (line 207) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 218) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 231) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 240) | class SetupCallback(Callback):
method __init__ (line 241) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 251) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_pretrain_routine_start (line 257) | def on_pretrain_routine_start(self, trainer, pl_module):
class ImageLogger (line 289) | class ImageLogger(Callback):
method __init__ (line 290) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _testtube (line 310) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 321) | def log_local(self, save_dir, split, images,
method log_img (line 340) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 372) | def check_frequency(self, check_idx):
method on_train_batch_end (line 383) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 387) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 395) | class CUDACallback(Callback):
method on_train_epoch_start (line 397) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 403) | def on_train_epoch_end(self, trainer, pl_module, outputs):
function melk (line 697) | def melk(*args, **kwargs):
function divein (line 705) | def divein(*args, **kwargs):
FILE: models/InstructDiffusion/stable_diffusion/notebook_helpers.py
function download_models (line 19) | def download_models(mode):
function load_model_from_config (line 40) | def load_model_from_config(config, ckpt):
function get_model (line 52) | def get_model(mode):
function get_custom_cond (line 59) | def get_custom_cond(mode):
function get_cond_options (line 85) | def get_cond_options(mode):
function select_cond_path (line 92) | def select_cond_path(mode):
function get_cond (line 107) | def get_cond(mode, selected_path):
function visualize_cond_img (line 127) | def visualize_cond_img(path):
function run (line 131) | def run(model, selected_path, task, custom_steps, resize_enabled=False, ...
function convsample_ddim (line 188) | def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, n...
function make_convolutional_sample (line 208) | def make_convolutional_sample(batch, model, mode="vanilla", custom_steps...
FILE: models/InstructDiffusion/stable_diffusion/scripts/img2img.py
function chunk (line 23) | def chunk(it, size):
function load_model_from_config (line 28) | def load_model_from_config(config, ckpt, verbose=False):
function load_img (line 48) | def load_img(path):
function main (line 60) | def main():
FILE: models/InstructDiffusion/stable_diffusion/scripts/inpaint.py
function make_batch (line 11) | def make_batch(image, mask, device):
FILE: models/InstructDiffusion/stable_diffusion/scripts/knn2img.py
function chunk (line 36) | def chunk(it, size):
function load_model_from_config (line 41) | def load_model_from_config(config, ckpt, verbose=False):
class Searcher (line 61) | class Searcher(object):
method __init__ (line 62) | def __init__(self, database, retriever_version='ViT-L/14'):
method train_searcher (line 75) | def train_searcher(self, k,
method load_single_file (line 91) | def load_single_file(self, saved_embeddings):
method load_multi_files (line 96) | def load_multi_files(self, data_archive):
method load_database (line 104) | def load_database(self):
method load_retriever (line 123) | def load_retriever(self, version='ViT-L/14', ):
method load_searcher (line 130) | def load_searcher(self):
method search (line 135) | def search(self, x, k):
method __call__ (line 163) | def __call__(self, x, n):
FILE: models/InstructDiffusion/stable_diffusion/scripts/sample_diffusion.py
function custom_to_pil (line 15) | def custom_to_pil(x):
function custom_to_np (line 27) | def custom_to_np(x):
function logs2pil (line 36) | def logs2pil(logs, keys=["sample"]):
function convsample (line 54) | def convsample(model, shape, return_intermediates=True,
function convsample_ddim (line 69) | def convsample_ddim(model, steps, shape, eta=1.0
function make_convolutional_sample (line 79) | def make_convolutional_sample(model, batch_size, vanilla=False, custom_s...
function run (line 108) | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, ...
function save_logs (line 143) | def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
function get_parser (line 162) | def get_parser():
function load_model_from_config (line 220) | def load_model_from_config(config, sd):
function load_model (line 228) | def load_model(config, ckpt, gpu, eval_mode):
FILE: models/InstructDiffusion/stable_diffusion/scripts/tests/test_watermark.py
function testit (line 6) | def testit(img_path):
FILE: models/InstructDiffusion/stable_diffusion/scripts/train_searcher.py
function search_bruteforce (line 12) | def search_bruteforce(searcher):
function search_partioned_ah (line 16) | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
function search_ah (line 24) | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
function load_datapool (line 28) | def load_datapool(dpath):
function train_searcher (line 62) | def train_searcher(opt,
FILE: models/InstructDiffusion/stable_diffusion/scripts/txt2img.py
function chunk (line 32) | def chunk(it, size):
function numpy_to_pil (line 37) | def numpy_to_pil(images):
function load_model_from_config (line 49) | def load_model_from_config(config, ckpt, verbose=False):
function put_watermark (line 69) | def put_watermark(img, wm_encoder=None):
function load_replacement (line 77) | def load_replacement(x):
function check_safety (line 88) | def check_safety(x_image):
function main (line 98) | def main():
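txt2img.py's numpy_to_pil converts the decoded float batch into PIL images for saving; a minimal sketch assuming (N, H, W, C) inputs in [0, 1]:

```python
import numpy as np
from PIL import Image

def numpy_to_pil(images):
    """Convert a (N, H, W, C) float array in [0, 1] to a list of PIL images (sketch)."""
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in images]
```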
FILE: models/InstructDiffusion/utils/deepspeed.py
function create_ds_config (line 7) | def create_ds_config(args, config, cfgdir):
FILE: models/InstructDiffusion/utils/logger.py
function create_logger (line 16) | def create_logger(output_dir, dist_rank=0, name=''):
FILE: models/InstructDiffusion/utils/utils.py
function load_checkpoint (line 14) | def load_checkpoint(file_name, config, model, model_ema, optimizer, lr_s...
function save_checkpoint (line 55) | def save_checkpoint(ckptdir, config, epoch, model, model_ema, max_accura...
function get_grad_norm (line 90) | def get_grad_norm(parameters, norm_type=2):
function auto_resume_helper (line 103) | def auto_resume_helper(config, output_dir):
function reduce_tensor (line 128) | def reduce_tensor(tensor):
function ampscaler_get_grad_norm (line 135) | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch...
class NativeScalerWithGradNormCount (line 151) | class NativeScalerWithGradNormCount:
method __init__ (line 154) | def __init__(self):
method __call__ (line 157) | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, c...
method state_dict (line 173) | def state_dict(self):
method load_state_dict (line 176) | def load_state_dict(self, state_dict):
FILE: models/edict/edict_functions.py
function EDICT_editing (line 56) | def EDICT_editing(im_path,
function recon_test (line 118) | def recon_test(im, steps=50, strength=1.0,
function im_to_np (line 170) | def im_to_np(im):
function img2img_editing (line 173) | def img2img_editing(im_path,
function center_crop (line 190) | def center_crop(im):
function load_im_into_format_from_path (line 203) | def load_im_into_format_from_path(im_path):
function init_attention_weights (line 208) | def init_attention_weights(weight_tuples):
function init_attention_edit (line 225) | def init_attention_edit(tokens, tokens_edit):
function init_attention_func (line 250) | def init_attention_func():
function use_last_tokens_attention (line 299) | def use_last_tokens_attention(use=True):
function use_last_tokens_attention_weights (line 305) | def use_last_tokens_attention_weights(use=True):
function use_last_self_attention (line 311) | def use_last_self_attention(use=True):
function save_last_tokens_attention (line 317) | def save_last_tokens_attention(save=True):
function save_last_self_attention (line 323) | def save_last_self_attention(save=True):
function baseline_stablediffusion (line 334) | def baseline_stablediffusion(prompt="",
function get_alpha_and_beta (line 599) | def get_alpha_and_beta(t, scheduler):
function forward_step (line 621) | def forward_step(
function reverse_step (line 653) | def reverse_step(
function latent_to_image (line 690) | def latent_to_image(latent):
function prep_image_for_return (line 695) | def prep_image_for_return(image):
function coupled_stablediffusion (line 708) | def coupled_stablediffusion(prompt="",
FILE: models/edict/my_diffusers/commands/__init__.py
class BaseDiffusersCLICommand (line 19) | class BaseDiffusersCLICommand(ABC):
method register_subcommand (line 22) | def register_subcommand(parser: ArgumentParser):
method run (line 26) | def run(self):
FILE: models/edict/my_diffusers/commands/diffusers_cli.py
function main (line 21) | def main():
FILE: models/edict/my_diffusers/commands/env.py
function info_command_factory (line 25) | def info_command_factory(_):
class EnvironmentCommand (line 29) | class EnvironmentCommand(BaseDiffusersCLICommand):
method register_subcommand (line 31) | def register_subcommand(parser: ArgumentParser):
method run (line 35) | def run(self):
method format_dict (line 69) | def format_dict(d):
FILE: models/edict/my_diffusers/configuration_utils.py
class ConfigMixin (line 38) | class ConfigMixin:
method register_to_config (line 54) | def register_to_config(self, **kwargs):
method save_config (line 76) | def save_config(self, save_directory: Union[str, os.PathLike], push_to...
method from_config (line 97) | def from_config(cls, pretrained_model_name_or_path: Union[str, os.Path...
method get_config_dict (line 168) | def get_config_dict(
method extract_init_dict (line 268) | def extract_init_dict(cls, config_dict, **kwargs):
method _dict_from_json_file (line 297) | def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
method __repr__ (line 302) | def __repr__(self):
method config (line 306) | def config(self) -> Dict[str, Any]:
method to_json_string (line 309) | def to_json_string(self) -> str:
method to_json_file (line 319) | def to_json_file(self, json_file_path: Union[str, os.PathLike]):
class FrozenDict (line 331) | class FrozenDict(OrderedDict):
method __init__ (line 332) | def __init__(self, *args, **kwargs):
method __delitem__ (line 340) | def __delitem__(self, *args, **kwargs):
method setdefault (line 343) | def setdefault(self, *args, **kwargs):
method pop (line 346) | def pop(self, *args, **kwargs):
method update (line 349) | def update(self, *args, **kwargs):
method __setattr__ (line 352) | def __setattr__(self, name, value):
method __setitem__ (line 357) | def __setitem__(self, name, value):
function register_to_config (line 363) | def register_to_config(init):
FILE: models/edict/my_diffusers/dependency_versions_check.py
function dep_version_check (line 46) | def dep_version_check(pkg, hint=None):
FILE: models/edict/my_diffusers/dynamic_modules_utils.py
function init_hf_modules (line 33) | def init_hf_modules():
function create_dynamic_module (line 48) | def create_dynamic_module(name: Union[str, os.PathLike]):
function get_relative_imports (line 63) | def get_relative_imports(module_file):
function get_relative_import_files (line 81) | def get_relative_import_files(module_file):
function check_imports (line 110) | def check_imports(filename):
function get_class_in_module (line 142) | def get_class_in_module(class_name, module_path):
function get_cached_module_file (line 151) | def get_cached_module_file(
function get_class_from_dynamic_module (line 249) | def get_class_from_dynamic_module(
FILE: models/edict/my_diffusers/hub_utils.py
function get_full_repo_name (line 38) | def get_full_repo_name(model_id: str, organization: Optional[str] = None...
function init_git_repo (line 48) | def init_git_repo(args, at_init: bool = False):
function push_to_hub (line 96) | def push_to_hub(
function create_model_card (line 152) | def create_model_card(args, model_name):
FILE: models/edict/my_diffusers/modeling_utils.py
function get_parameter_device (line 36) | def get_parameter_device(parameter: torch.nn.Module):
function get_parameter_dtype (line 51) | def get_parameter_dtype(parameter: torch.nn.Module):
function load_state_dict (line 66) | def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
function _load_state_dict_into_model (line 94) | def _load_state_dict_into_model(model_to_load, state_dict):
class ModelMixin (line 115) | class ModelMixin(torch.nn.Module):
method __init__ (line 128) | def __init__(self):
method save_pretrained (line 131) | def save_pretrained(
method from_pretrained (line 182) | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union...
method _load_pretrained_model (line 384) | def _load_pretrained_model(
method device (line 488) | def device(self) -> device:
method dtype (line 496) | def dtype(self) -> torch.dtype:
method num_parameters (line 502) | def num_parameters(self, only_trainable: bool = False, exclude_embeddi...
function unwrap_model (line 531) | def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
FILE: models/edict/my_diffusers/models/attention.py
class AttentionBlock (line 9) | class AttentionBlock(nn.Module):
method __init__ (line 25) | def __init__(
method transpose_for_scores (line 48) | def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
method forward (line 54) | def forward(self, hidden_states):
class SpatialTransformer (line 95) | class SpatialTransformer(nn.Module):
method __init__ (line 109) | def __init__(
method _set_attention_slice (line 136) | def _set_attention_slice(self, slice_size):
method forward (line 140) | def forward(self, x, context=None):
class BasicTransformerBlock (line 154) | class BasicTransformerBlock(nn.Module):
method __init__ (line 168) | def __init__(
method _set_attention_slice (line 191) | def _set_attention_slice(self, slice_size):
method forward (line 195) | def forward(self, x, context=None):
class CrossAttention (line 203) | class CrossAttention(nn.Module):
method __init__ (line 216) | def __init__(
method reshape_heads_to_batch_dim (line 236) | def reshape_heads_to_batch_dim(self, tensor):
method reshape_batch_dim_to_heads (line 243) | def reshape_batch_dim_to_heads(self, tensor):
method forward (line 250) | def forward(self, x, context=None, mask=None):
method _attention (line 269) | def _attention(self, query, key, value, sequence_length, dim):
class FeedForward (line 291) | class FeedForward(nn.Module):
method __init__ (line 303) | def __init__(
method forward (line 313) | def forward(self, x):
class GEGLU (line 318) | class GEGLU(nn.Module):
method __init__ (line 327) | def __init__(self, dim_in: int, dim_out: int):
method forward (line 331) | def forward(self, x):
FILE: models/edict/my_diffusers/models/embeddings.py
function get_timestep_embedding (line 21) | def get_timestep_embedding(
class TimestepEmbedding (line 63) | class TimestepEmbedding(nn.Module):
method __init__ (line 64) | def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "s...
method forward (line 73) | def forward(self, sample):
class Timesteps (line 83) | class Timesteps(nn.Module):
method __init__ (line 84) | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale...
method forward (line 90) | def forward(self, timesteps):
class GaussianFourierProjection (line 100) | class GaussianFourierProjection(nn.Module):
method __init__ (line 103) | def __init__(self, embedding_size: int = 256, scale: float = 1.0):
method forward (line 112) | def forward(self, x):
FILE: models/edict/my_diffusers/models/resnet.py
class Upsample2D (line 9) | class Upsample2D(nn.Module):
method __init__ (line 18) | def __init__(self, channels, use_conv=False, use_conv_transpose=False,...
method forward (line 38) | def forward(self, x):
class Downsample2D (line 55) | class Downsample2D(nn.Module):
method __init__ (line 64) | def __init__(self, channels, use_conv=False, out_channels=None, paddin...
method forward (line 88) | def forward(self, x):
class FirUpsample2D (line 100) | class FirUpsample2D(nn.Module):
method __init__ (line 101) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
method _upsample_2d (line 110) | def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
method forward (line 179) | def forward(self, x):
class FirDownsample2D (line 189) | class FirDownsample2D(nn.Module):
method __init__ (line 190) | def __init__(self, channels=None, out_channels=None, use_conv=False, f...
method _downsample_2d (line 199) | def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
method forward (line 241) | def forward(self, x):
class ResnetBlock2D (line 251) | class ResnetBlock2D(nn.Module):
method __init__ (line 252) | def __init__(
method forward (line 331) | def forward(self, x, temb):
class Mish (line 368) | class Mish(torch.nn.Module):
method forward (line 369) | def forward(self, x):
function upsample_2d (line 373) | def upsample_2d(x, kernel=None, factor=2, gain=1):
function downsample_2d (line 406) | def downsample_2d(x, kernel=None, factor=2, gain=1):
function upfirdn2d_native (line 438) | def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
FILE: models/edict/my_diffusers/models/unet_2d.py
class UNet2DOutput (line 15) | class UNet2DOutput(BaseOutput):
class UNet2DModel (line 25) | class UNet2DModel(ModelMixin, ConfigMixin):
method __init__ (line 59) | def __init__(
method forward (line 165) | def forward(
FILE: models/edict/my_diffusers/models/unet_2d_condition.py
class UNet2DConditionOutput (line 15) | class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel (line 25) | class UNet2DConditionModel(ModelMixin, ConfigMixin):
method __init__ (line 58) | def __init__(
method set_attention_slice (line 167) | def set_attention_slice(self, slice_size):
method forward (line 189) | def forward(
FILE: models/edict/my_diffusers/models/unet_blocks.py
function get_down_block (line 24) | def get_down_block(
function get_up_block (line 111) | def get_up_block(
class UNetMidBlock2D (line 198) | class UNetMidBlock2D(nn.Module):
method __init__ (line 199) | def __init__(
method forward (line 265) | def forward(self, hidden_states, temb=None, encoder_states=None):
class UNetMidBlock2DCrossAttn (line 277) | class UNetMidBlock2DCrossAttn(nn.Module):
method __init__ (line 278) | def __init__(
method set_attention_slice (line 346) | def set_attention_slice(self, slice_size):
method forward (line 361) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
class AttnDownBlock2D (line 370) | class AttnDownBlock2D(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 434) | def forward(self, hidden_states, temb=None):
class CrossAttnDownBlock2D (line 451) | class CrossAttnDownBlock2D(nn.Module):
method __init__ (line 452) | def __init__(
method set_attention_slice (line 517) | def set_attention_slice(self, slice_size):
method forward (line 532) | def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
class DownBlock2D (line 549) | class DownBlock2D(nn.Module):
method __init__ (line 550) | def __init__(
method forward (line 599) | def forward(self, hidden_states, temb=None):
class DownEncoderBlock2D (line 615) | class DownEncoderBlock2D(nn.Module):
method __init__ (line 616) | def __init__(
method forward (line 664) | def forward(self, hidden_states):
class AttnDownEncoderBlock2D (line 675) | class AttnDownEncoderBlock2D(nn.Module):
method __init__ (line 676) | def __init__(
method forward (line 736) | def forward(self, hidden_states):
class AttnSkipDownBlock2D (line 748) | class AttnSkipDownBlock2D(nn.Module):
method __init__ (line 749) | def __init__(
method forward (line 821) | def forward(self, hidden_states, temb=None, skip_sample=None):
class SkipDownBlock2D (line 841) | class SkipDownBlock2D(nn.Module):
method __init__ (line 842) | def __init__(
method forward (line 901) | def forward(self, hidden_states, temb=None, skip_sample=None):
class AttnUpBlock2D (line 920) | class AttnUpBlock2D(nn.Module):
method __init__ (line 921) | def __init__(
method forward (line 980) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
class CrossAttnUpBlock2D (line 998) | class CrossAttnUpBlock2D(nn.Module):
method __init__ (line 999) | def __init__(
method set_attention_slice (line 1061) | def set_attention_slice(self, slice_size):
method forward (line 1076) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, e...
class UpBlock2D (line 1094) | class UpBlock2D(nn.Module):
method __init__ (line 1095) | def __init__(
method forward (line 1140) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
class UpDecoderBlock2D (line 1157) | class UpDecoderBlock2D(nn.Module):
method __init__ (line 1158) | def __init__(
method forward (line 1200) | def forward(self, hidden_states):
class AttnUpDecoderBlock2D (line 1211) | class AttnUpDecoderBlock2D(nn.Module):
method __init__ (line 1212) | def __init__(
method forward (line 1266) | def forward(self, hidden_states):
class AttnSkipUpBlock2D (line 1278) | class AttnSkipUpBlock2D(nn.Module):
method __init__ (line 1279) | def __init__(
method forward (line 1361) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
class SkipUpBlock2D (line 1389) | class SkipUpBlock2D(nn.Module):
method __init__ (line 1390) | def __init__(
method forward (line 1458) | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, s...
FILE: models/edict/my_diffusers/models/vae.py
class DecoderOutput (line 15) | class DecoderOutput(BaseOutput):
class VQEncoderOutput (line 28) | class VQEncoderOutput(BaseOutput):
class AutoencoderKLOutput (line 41) | class AutoencoderKLOutput(BaseOutput):
class Encoder (line 54) | class Encoder(nn.Module):
method __init__ (line 55) | def __init__(
method forward (line 114) | def forward(self, x):
class Decoder (line 133) | class Decoder(nn.Module):
method __init__ (line 134) | def __init__(
method forward (line 193) | def forward(self, z):
class VectorQuantizer (line 212) | class VectorQuantizer(nn.Module):
method __init__ (line 221) | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random...
method remap_to_used (line 248) | def remap_to_used(self, inds):
method unmap_to_all (line 262) | def unmap_to_all(self, inds):
method forward (line 272) | def forward(self, z):
method get_codebook_entry (line 311) | def get_codebook_entry(self, indices, shape):
class DiagonalGaussianDistribution (line 329) | class DiagonalGaussianDistribution(object):
method __init__ (line 330) | def __init__(self, parameters, deterministic=False):
method sample (line 340) | def sample(self, generator: Optional[torch.Generator] = None) -> torch...
method kl (line 347) | def kl(self, other=None):
method nll (line 363) | def nll(self, sample, dims=[1, 2, 3]):
method mode (line 369) | def mode(self):
class VQModel (line 373) | class VQModel(ModelMixin, ConfigMixin):
method __init__ (line 396) | def __init__(
method encode (line 438) | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQ...
method decode (line 447) | def decode(
method forward (line 463) | def forward(self, sample: torch.FloatTensor, return_dict: bool = True)...
class AutoencoderKL (line 480) | class AutoencoderKL(ModelMixin, ConfigMixin):
method __init__ (line 502) | def __init__(
method encode (line 540) | def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Au...
method decode (line 550) | def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Un...
method forward (line 559) | def forward(
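vae.py's VectorQuantizer maps each latent vector to its nearest codebook entry; a minimal sketch of that lookup, omitting the commitment loss, index remapping, and straight-through estimator of the real class:

```python
import torch

def nearest_codebook_entries(z_flat, codebook):
    """z_flat: (N, e_dim), codebook: (n_e, e_dim) -> (indices, quantized vectors)."""
    # Squared Euclidean distance ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
    d = (z_flat.pow(2).sum(dim=1, keepdim=True)
         + codebook.pow(2).sum(dim=1)
         - 2.0 * z_flat @ codebook.t())
    indices = torch.argmin(d, dim=1)
    return indices, codebook[indices]
```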
FILE: models/edict/my_diffusers/onnx_utils.py
class OnnxRuntimeModel (line 40) | class OnnxRuntimeModel:
method __init__ (line 43) | def __init__(self, model=None, **kwargs):
method __call__ (line 49) | def __call__(self, **kwargs):
method load_model (line 54) | def load_model(path: Union[str, Path], provider=None):
method _save_pretrained (line 70) | def _save_pretrained(self, save_directory: Union[str, Path], file_name...
method save_pretrained (line 90) | def save_pretrained(
method _from_pretrained (line 113) | def _from_pretrained(
method from_pretrained (line 170) | def from_pretrained(
FILE: models/edict/my_diffusers/optimization.py
class SchedulerType (line 30) | class SchedulerType(Enum):
function get_constant_schedule (line 39) | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
function get_constant_schedule_with_warmup (line 55) | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_s...
function get_linear_schedule_with_warmup (line 80) | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_tra...
function get_cosine_schedule_with_warmup (line 109) | def get_cosine_schedule_with_warmup(
function get_cosine_with_hard_restarts_schedule_with_warmup (line 143) | def get_cosine_with_hard_restarts_schedule_with_warmup(
function get_polynomial_decay_schedule_with_warmup (line 178) | def get_polynomial_decay_schedule_with_warmup(
function get_scheduler (line 238) | def get_scheduler(
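optimization.py mirrors the transformers-style LR schedulers; get_linear_schedule_with_warmup typically wraps LambdaLR with the multiplier below (sketch):

```python
from torch.optim.lr_scheduler import LambdaLR

def linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """LR multiplier: linear ramp to 1 over warmup, then linear decay to 0 (sketch)."""
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))
    return LambdaLR(optimizer, lr_lambda, last_epoch)
```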
FILE: models/edict/my_diffusers/pipeline_utils.py
class ImagePipelineOutput (line 63) | class ImagePipelineOutput(BaseOutput):
class DiffusionPipeline (line 76) | class DiffusionPipeline(ConfigMixin):
method register_modules (line 93) | def register_modules(self, **kwargs):
method save_pretrained (line 123) | def save_pretrained(self, save_directory: Union[str, os.PathLike]):
method to (line 160) | def to(self, torch_device: Optional[Union[str, torch.device]] = None):
method device (line 172) | def device(self) -> torch.device:
method from_pretrained (line 185) | def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union...
method numpy_to_pil (line 395) | def numpy_to_pil(images):
method progress_bar (line 406) | def progress_bar(self, iterable):
method set_progress_bar_config (line 416) | def set_progress_bar_config(self, **kwargs):
FILE: models/edict/my_diffusers/pipelines/ddim/pipeline_ddim.py
class DDIMPipeline (line 25) | class DDIMPipeline(DiffusionPipeline):
method __init__ (line 37) | def __init__(self, unet, scheduler):
method __call__ (line 43) | def __call__(
FILE: models/edict/my_diffusers/pipelines/ddpm/pipeline_ddpm.py
class DDPMPipeline (line 25) | class DDPMPipeline(DiffusionPipeline):
method __init__ (line 37) | def __init__(self, unet, scheduler):
method __call__ (line 43) | def __call__(
FILE: models/edict/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
class LDMTextToImagePipeline (line 21) | class LDMTextToImagePipeline(DiffusionPipeline):
method __init__ (line 40) | def __init__(
method __call__ (line 53) | def __call__(
class LDMBertConfig (line 202) | class LDMBertConfig(PretrainedConfig):
method __init__ (line 207) | def __init__(
function _expand_mask (line 249) | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Option...
class LDMBertAttention (line 264) | class LDMBertAttention(nn.Module):
method __init__ (line 267) | def __init__(
method _shape (line 291) | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
method forward (line 294) | def forward(
class LDMBertEncoderLayer (line 408) | class LDMBertEncoderLayer(nn.Module):
method __init__ (line 409) | def __init__(self, config: LDMBertConfig):
method forward (line 426) | def forward(
class LDMBertPreTrainedModel (line 478) | class LDMBertPreTrainedModel(PreTrainedModel):
method _init_weights (line 484) | def _init_weights(self, module):
method _set_gradient_checkpointing (line 495) | def _set_gradient_checkpointing(self, module, value=False):
method dummy_inputs (line 500) | def dummy_inputs(self):
class LDMBertEncoder (line 510) | class LDMBertEncoder(LDMBertPreTrainedModel):
method __init__ (line 520) | def __init__(self, config: LDMBertConfig):
method get_input_embeddings (line 538) | def get_input_embeddings(self):
method set_input_embeddings (line 541) | def set_input_embeddings(self, value):
method forward (line 544) | def forward(
class LDMBertModel (line 677) | class LDMBertModel(LDMBertPreTrainedModel):
method __init__ (line 678) | def __init__(self, config: LDMBertConfig):
method forward (line 683) | def forward(
FILE: models/edict/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
class LDMPipeline (line 12) | class LDMPipeline(DiffusionPipeline):
method __init__ (line 25) | def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMS...
method __call__ (line 31) | def __call__(
FILE: models/edict/my_diffusers/pipelines/pndm/pipeline_pndm.py
class PNDMPipeline (line 27) | class PNDMPipeline(DiffusionPipeline):
method __init__ (line 41) | def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
method __call__ (line 47) | def __call__(
FILE: models/edict/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
class ScoreSdeVePipeline (line 12) | class ScoreSdeVePipeline(DiffusionPipeline):
method __init__ (line 23) | def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
method __call__ (line 28) | def __call__(
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/__init__.py
class StableDiffusionPipelineOutput (line 13) | class StableDiffusionPipelineOutput(BaseOutput):
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
class StableDiffusionPipeline (line 16) | class StableDiffusionPipeline(DiffusionPipeline):
method __init__ (line 44) | def __init__(
method enable_attention_slicing (line 66) | def enable_attention_slicing(self, slice_size: Optional[Union[str, int...
method disable_attention_slicing (line 85) | def disable_attention_slicing(self):
method __call__ (line 94) | def __call__(
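
At the core of `StableDiffusionPipeline.__call__` is classifier-free guidance: the UNet is run on an unconditional and a prompt-conditioned embedding, and the two noise predictions are combined. A minimal sketch of that combination step, with the guidance scale value purely illustrative:

```python
import torch

def classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    """Push the conditioned prediction away from the unconditional one."""
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

uncond = torch.randn(1, 4, 64, 64)
cond = torch.randn(1, 4, 64, 64)
guided = classifier_free_guidance(uncond, cond, guidance_scale=7.5)
print(guided.shape)   # torch.Size([1, 4, 64, 64])
```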
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
function preprocess (line 17) | def preprocess(image):
class StableDiffusionImg2ImgPipeline (line 27) | class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
method __init__ (line 55) | def __init__(
method enable_attention_slicing (line 77) | def enable_attention_slicing(self, slice_size: Optional[Union[str, int...
method disable_attention_slicing (line 96) | def disable_attention_slicing(self):
method __call__ (line 105) | def __call__(
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
function preprocess_image (line 22) | def preprocess_image(image):
function preprocess_mask (line 32) | def preprocess_mask(mask):
class StableDiffusionInpaintPipeline (line 45) | class StableDiffusionInpaintPipeline(DiffusionPipeline):
method __init__ (line 73) | def __init__(
method enable_attention_slicing (line 96) | def enable_attention_slicing(self, slice_size: Optional[Union[str, int...
method disable_attention_slicing (line 115) | def disable_attention_slicing(self):
method __call__ (line 124) | def __call__(
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
class StableDiffusionOnnxPipeline (line 14) | class StableDiffusionOnnxPipeline(DiffusionPipeline):
method __init__ (line 23) | def __init__(
method __call__ (line 45) | def __call__(
FILE: models/edict/my_diffusers/pipelines/stable_diffusion/safety_checker.py
function cosine_distance (line 13) | def cosine_distance(image_embeds, text_embeds):
class StableDiffusionSafetyChecker (line 19) | class StableDiffusionSafetyChecker(PreTrainedModel):
method __init__ (line 22) | def __init__(self, config: CLIPConfig):
method forward (line 35) | def forward(self, clip_input, images):
method forward_onnx (line 83) | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.Fl...
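
The `cosine_distance` helper listed above scores CLIP image embeddings against concept embeddings inside the safety checker. A self-contained sketch matching that signature (the embedding sizes below are placeholders):

```python
import torch
import torch.nn.functional as F

def cosine_distance(image_embeds, text_embeds):
    """Pairwise cosine similarity between L2-normalized embedding sets."""
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    return image_embeds @ text_embeds.T

scores = cosine_distance(torch.randn(2, 768), torch.randn(17, 768))
print(scores.shape)   # torch.Size([2, 17])
```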
FILE: models/edict/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
class KarrasVePipeline (line 12) | class KarrasVePipeline(DiffusionPipeline):
method __init__ (line 31) | def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
method __call__ (line 37) | def __call__(
FILE: models/edict/my_diffusers/schedulers/scheduling_ddim.py
function betas_for_alpha_bar (line 28) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDIMScheduler (line 57) | class DDIMScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 87) | def __init__(
method _get_variance (line 131) | def _get_variance(self, timestep, prev_timestep):
method set_timesteps (line 141) | def set_timesteps(self, num_inference_steps: int, offset: int = 0):
method step (line 163) | def step(
method add_noise (line 255) | def add_noise(
method __len__ (line 269) | def __len__(self):
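
`DDIMScheduler.step` implements the deterministic DDIM update. The standalone sketch below shows the eta = 0 case (predict x0 from the noise estimate, then re-noise to the previous timestep); the alpha values are toy numbers, and clipping and the stochastic eta > 0 branch are omitted.

```python
import torch

def ddim_step(model_output, sample, alpha_prod_t, alpha_prod_t_prev):
    """One deterministic DDIM update (eta = 0), simplified."""
    pred_x0 = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
    dir_xt = (1 - alpha_prod_t_prev) ** 0.5 * model_output
    return alpha_prod_t_prev ** 0.5 * pred_x0 + dir_xt

x_t = torch.randn(1, 4, 64, 64)
eps = torch.randn_like(x_t)          # stand-in for the UNet's noise prediction
x_prev = ddim_step(eps, x_t, alpha_prod_t=0.5, alpha_prod_t_prev=0.7)
print(x_prev.shape)
```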
FILE: models/edict/my_diffusers/schedulers/scheduling_ddpm.py
function betas_for_alpha_bar (line 27) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDPMScheduler (line 56) | class DDPMScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 86) | def __init__(
method set_timesteps (line 124) | def set_timesteps(self, num_inference_steps: int):
method _get_variance (line 139) | def _get_variance(self, t, predicted_variance=None, variance_type=None):
method step (line 172) | def step(
method add_noise (line 248) | def add_noise(
method __len__ (line 263) | def __len__(self):
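
`DDPMScheduler.add_noise` applies the closed-form forward diffusion q(x_t | x_0). A minimal sketch, assuming the standard linear beta schedule shown only as an example:

```python
import torch

def add_noise(original_samples, noise, alphas_cumprod, timesteps):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alphas_cumprod[timesteps] ** 0.5
    b = (1 - alphas_cumprod[timesteps]) ** 0.5
    # Broadcast the per-sample scalars over the image dimensions.
    a = a.view(-1, 1, 1, 1)
    b = b.view(-1, 1, 1, 1)
    return a * original_samples + b * noise

betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.randn(2, 4, 64, 64)
noisy = add_noise(x0, torch.randn_like(x0), alphas_cumprod, torch.tensor([10, 500]))
print(noisy.shape)
```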
FILE: models/edict/my_diffusers/schedulers/scheduling_karras_ve.py
class KarrasVeOutput (line 28) | class KarrasVeOutput(BaseOutput):
class KarrasVeScheduler (line 44) | class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 78) | def __init__(
method set_timesteps (line 96) | def set_timesteps(self, num_inference_steps: int):
method add_noise_to_input (line 115) | def add_noise_to_input(
method step (line 136) | def step(
method step_correct (line 172) | def step_correct(
method add_noise (line 207) | def add_noise(self, original_samples, noise, timesteps):
FILE: models/edict/my_diffusers/schedulers/scheduling_lms_discrete.py
class LMSDiscreteScheduler (line 26) | class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 53) | def __init__(
method get_lms_coefficient (line 86) | def get_lms_coefficient(self, order, t, current_order):
method set_timesteps (line 108) | def set_timesteps(self, num_inference_steps: int):
method step (line 130) | def step(
method add_noise (line 181) | def add_noise(
method __len__ (line 192) | def __len__(self):
FILE: models/edict/my_diffusers/schedulers/scheduling_pndm.py
function betas_for_alpha_bar (line 27) | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class PNDMScheduler (line 56) | class PNDMScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 84) | def __init__(
method set_timesteps (line 134) | def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> ...
method step (line 173) | def step(
method step_prk (line 204) | def step_prk(
method step_plms (line 259) | def step_plms(
method _get_prev_sample (line 325) | def _get_prev_sample(self, sample, timestep, timestep_prev, model_outp...
method add_noise (line 361) | def add_noise(
method __len__ (line 377) | def __len__(self):
FILE: models/edict/my_diffusers/schedulers/scheduling_sde_ve.py
class SdeVeOutput (line 30) | class SdeVeOutput(BaseOutput):
class ScoreSdeVeScheduler (line 46) | class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 71) | def __init__(
method set_timesteps (line 89) | def set_timesteps(self, num_inference_steps: int, sampling_eps: float ...
method set_sigmas (line 108) | def set_sigmas(
method get_adjacent_sigma (line 141) | def get_adjacent_sigma(self, timesteps, t):
method set_seed (line 154) | def set_seed(self, seed):
method step_pred (line 168) | def step_pred(
method step_correct (line 230) | def step_correct(
method __len__ (line 282) | def __len__(self):
FILE: models/edict/my_diffusers/schedulers/scheduling_sde_vp.py
class ScoreSdeVpScheduler (line 26) | class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
method __init__ (line 42) | def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20...
method set_timesteps (line 48) | def set_timesteps(self, num_inference_steps):
method step_pred (line 51) | def step_pred(self, score, x, t):
method __len__ (line 80) | def __len__(self):
FILE: models/edict/my_diffusers/schedulers/scheduling_utils.py
class SchedulerOutput (line 27) | class SchedulerOutput(BaseOutput):
class SchedulerMixin (line 40) | class SchedulerMixin:
method set_format (line 48) | def set_format(self, tensor_format="pt"):
method clip (line 57) | def clip(self, tensor, min_value=None, max_value=None):
method log (line 67) | def log(self, tensor):
method match_shape (line 77) | def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadca...
method norm (line 99) | def norm(self, tensor):
method randn_like (line 108) | def randn_like(self, tensor, generator=None):
method zeros_like (line 118) | def zeros_like(self, tensor):
FILE: models/edict/my_diffusers/testing_utils.py
function parse_flag_from_env (line 19) | def parse_flag_from_env(key, default=False):
function floats_tensor (line 38) | def floats_tensor(shape, scale=1.0, rng=None, name=None):
function slow (line 54) | def slow(test_case):
FILE: models/edict/my_diffusers/training_utils.py
function enable_full_determinism (line 9) | def enable_full_determinism(seed: int):
function set_seed (line 29) | def set_seed(seed: int):
class EMAModel (line 42) | class EMAModel:
method __init__ (line 47) | def __init__(
method get_decay (line 84) | def get_decay(self, optimization_step):
method step (line 97) | def step(self, new_model):
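
`EMAModel.step` maintains an exponential moving average of the weights during training. The sketch below captures only that core update; the decay warmup handled by `get_decay` is left out and the decay constant is illustrative.

```python
import copy
import torch

class SimpleEMA:
    """Shadow weights drift slowly toward the live model's weights."""
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def step(self, model):
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

net = torch.nn.Linear(8, 8)
ema = SimpleEMA(net)
ema.step(net)   # call once per optimizer step
```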
FILE: models/edict/my_diffusers/utils/dummy_scipy_objects.py
class LMSDiscreteScheduler (line 7) | class LMSDiscreteScheduler(metaclass=DummyObject):
method __init__ (line 10) | def __init__(self, *args, **kwargs):
FILE: models/edict/my_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
class GradTTSPipeline (line 6) | class GradTTSPipeline(metaclass=DummyObject):
method __init__ (line 9) | def __init__(self, *args, **kwargs):
FILE: models/edict/my_diffusers/utils/dummy_transformers_and_onnx_objects.py
class StableDiffusionOnnxPipeline (line 7) | class StableDiffusionOnnxPipeline(metaclass=DummyObject):
method __init__ (line 10) | def __init__(self, *args, **kwargs):
FILE: models/edict/my_diffusers/utils/dummy_transformers_objects.py
class LDMTextToImagePipeline (line 7) | class LDMTextToImagePipeline(metaclass=DummyObject):
method __init__ (line 10) | def __init__(self, *args, **kwargs):
class StableDiffusionImg2ImgPipeline (line 14) | class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
method __init__ (line 17) | def __init__(self, *args, **kwargs):
class StableDiffusionInpaintPipeline (line 21) | class StableDiffusionInpaintPipeline(metaclass=DummyObject):
method __init__ (line 24) | def __init__(self, *args, **kwargs):
class StableDiffusionPipeline (line 28) | class StableDiffusionPipeline(metaclass=DummyObject):
method __init__ (line 31) | def __init__(self, *args, **kwargs):
FILE: models/edict/my_diffusers/utils/import_utils.py
function is_torch_available (line 155) | def is_torch_available():
function is_tf_available (line 159) | def is_tf_available():
function is_flax_available (line 163) | def is_flax_available():
function is_transformers_available (line 167) | def is_transformers_available():
function is_inflect_available (line 171) | def is_inflect_available():
function is_unidecode_available (line 175) | def is_unidecode_available():
function is_modelcards_available (line 179) | def is_modelcards_available():
function is_onnx_available (line 183) | def is_onnx_available():
function is_scipy_available (line 187) | def is_scipy_available():
function requires_backends (line 254) | def requires_backends(obj, backends):
class DummyObject (line 265) | class DummyObject(type):
method __getattr__ (line 271) | def __getattr__(cls, key):
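
The dummy-object modules above all rely on the `DummyObject` metaclass so that missing optional backends fail lazily with a readable error rather than at import time. A small sketch of the pattern (the error message wording is invented; only the mechanism is the point):

```python
class DummyObject(type):
    """Placeholder classes raise a helpful error the moment they are touched."""
    def __getattr__(cls, key):
        raise ImportError(
            f"{cls.__name__} requires backends that are not installed: {cls._backends}"
        )

class StableDiffusionOnnxPipeline(metaclass=DummyObject):
    _backends = ["onnxruntime"]

try:
    StableDiffusionOnnxPipeline.from_pretrained   # any attribute access triggers the error
except ImportError as err:
    print(err)
```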
FILE: models/edict/my_diffusers/utils/logging.py
function _get_default_logging_level (line 50) | def _get_default_logging_level():
function _get_library_name (line 67) | def _get_library_name() -> str:
function _get_library_root_logger (line 72) | def _get_library_root_logger() -> logging.Logger:
function _configure_library_root_logger (line 77) | def _configure_library_root_logger() -> None:
function _reset_library_root_logger (line 95) | def _reset_library_root_logger() -> None:
function get_log_levels_dict (line 109) | def get_log_levels_dict():
function get_logger (line 113) | def get_logger(name: Optional[str] = None) -> logging.Logger:
function get_verbosity (line 127) | def get_verbosity() -> int:
function set_verbosity (line 150) | def set_verbosity(verbosity: int) -> None:
function set_verbosity_info (line 169) | def set_verbosity_info():
function set_verbosity_warning (line 174) | def set_verbosity_warning():
function set_verbosity_debug (line 179) | def set_verbosity_debug():
function set_verbosity_error (line 184) | def set_verbosity_error():
function disable_default_handler (line 189) | def disable_default_handler() -> None:
function enable_default_handler (line 198) | def enable_default_handler() -> None:
function add_handler (line 207) | def add_handler(handler: logging.Handler) -> None:
function remove_handler (line 216) | def remove_handler(handler: logging.Handler) -> None:
function disable_propagation (line 225) | def disable_propagation() -> None:
function enable_propagation (line 234) | def enable_propagation() -> None:
function enable_explicit_format (line 244) | def enable_explicit_format() -> None:
function reset_format (line 259) | def reset_format() -> None:
function warning_advice (line 271) | def warning_advice(self, *args, **kwargs):
class EmptyTqdm (line 285) | class EmptyTqdm:
method __init__ (line 288) | def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
method __iter__ (line 291) | def __iter__(self):
method __getattr__ (line 294) | def __getattr__(self, _):
method __enter__ (line 302) | def __enter__(self):
method __exit__ (line 305) | def __exit__(self, type_, value, traceback):
class _tqdm_cls (line 309) | class _tqdm_cls:
method __call__ (line 310) | def __call__(self, *args, **kwargs):
method set_lock (line 316) | def set_lock(self, *args, **kwargs):
method get_lock (line 321) | def get_lock(self):
function is_progress_bar_enabled (line 329) | def is_progress_bar_enabled() -> bool:
function enable_progress_bar (line 335) | def enable_progress_bar():
function disable_progress_bar (line 341) | def disable_progress_bar():
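
The logging helpers above manage a single library-wide root logger on top of the standard `logging` module. A stand-in sketch of the `get_logger` / `set_verbosity` idea using only the stdlib (the library name string is just an example):

```python
import logging

logging.basicConfig()          # give the root logger a handler so messages are printed
_LIBRARY = "my_diffusers"

def get_logger(name=None):
    return logging.getLogger(name or _LIBRARY)

def set_verbosity(level):
    logging.getLogger(_LIBRARY).setLevel(level)

set_verbosity(logging.INFO)
get_logger(f"{_LIBRARY}.pipelines").info("child loggers inherit the library-wide level")
```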
FILE: models/edict/my_diffusers/utils/outputs.py
function is_tensor (line 28) | def is_tensor(x):
class BaseOutput (line 41) | class BaseOutput(OrderedDict):
method __post_init__ (line 55) | def __post_init__(self):
method __delitem__ (line 67) | def __delitem__(self, *args, **kwargs):
method setdefault (line 70) | def setdefault(self, *args, **kwargs):
method pop (line 73) | def pop(self, *args, **kwargs):
method update (line 76) | def update(self, *args, **kwargs):
method __getitem__ (line 79) | def __getitem__(self, k):
method __setattr__ (line 93) | def __setattr__(self, name, value):
method __setitem__ (line 99) | def __setitem__(self, key, value):
method to_tuple (line 105) | def to_tuple(self) -> Tuple[Any]:
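
`BaseOutput` lets pipeline results be used both as dataclasses (`out.images`) and as ordered dicts (`out["images"]`, `out.to_tuple()`). The toy class below sketches that dual behavior under simplified assumptions; the real class also overrides item/attribute setters and forbids in-place mutation helpers.

```python
from collections import OrderedDict
from dataclasses import dataclass, fields

@dataclass
class TinyOutput(OrderedDict):
    """Sketch of the BaseOutput idea: dataclass fields mirrored into dict keys."""
    images: list = None
    nsfw_content_detected: list = None

    def __post_init__(self):
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                self[f.name] = value

    def to_tuple(self):
        return tuple(self[k] for k in self.keys())

out = TinyOutput(images=["img0"])
print(out.images, out["images"], out.to_tuple())
```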
FILE: models/edit_friendly_ddm/inversion_utils.py
function load_real_image (line 4) | def load_real_image(folder = "data/", img_name = None, idx = 0, img_size...
function mu_tilde (line 22) | def mu_tilde(model, xt,x0, timestep):
function sample_xts_from_x0 (line 31) | def sample_xts_from_x0(model, x0, num_inference_steps=50):
function encode_text (line 58) | def encode_text(model, prompts):
function forward_step (line 70) | def forward_step(model, model_output, timestep, sample):
function get_variance (line 91) | def get_variance(model, timestep): #, prev_timestep):
function inversion_forward_process (line 100) | def inversion_forward_process(model, x0,
function reverse_step (line 179) | def reverse_step(model, model_output, timestep, sample, eta = 0, varianc...
function inversion_reverse_process (line 210) | def inversion_reverse_process(model,
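
The edit-friendly DDPM inversion listed above (`inversion_forward_process` / `inversion_reverse_process`) hinges on computing, rather than sampling, the per-step noise maps so the reverse process reproduces a pre-sampled trajectory exactly. A minimal sketch of that core relation, with toy tensors and a placeholder sigma value:

```python
import torch

def edit_friendly_noise(x_prev, mu_t, sigma_t):
    """If the reverse step is x_{t-1} = mu_t(x_t) + sigma_t * z_t, then the
    stored noise map is recovered as z_t = (x_{t-1} - mu_t) / sigma_t."""
    return (x_prev - mu_t) / sigma_t

x_prev = torch.randn(1, 4, 64, 64)   # pre-sampled x_{t-1} on the inversion trajectory
mu_t = torch.randn_like(x_prev)      # stand-in for the model's predicted posterior mean
z_t = edit_friendly_noise(x_prev, mu_t, sigma_t=0.1)

# Plugging z_t back into the reverse step reproduces x_{t-1} exactly.
assert torch.allclose(mu_t + 0.1 * z_t, x_prev, atol=1e-6)
```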
FILE: models/edit_friendly_ddm/ptp_classes.py
class LocalBlend (line 20) | class LocalBlend:
method __call__ (line 22) | def __call__(self, x_t, attention_store):
method __init__ (line 36) | def __init__(self, prompts: List[str], words: [List[List[str]]], thres...
class AttentionControl (line 48) | class AttentionControl(abc.ABC):
method step_callback (line 50) | def step_callback(self, x_t):
method between_steps (line 53) | def between_steps(self):
method num_uncond_att_layers (line 57) | def num_uncond_att_layers(self):
method forward (line 61) | def forward (self, attn, is_cross: bool, place_in_unet: str):
method __call__ (line 64) | def __call__(self, attn, is_cross: bool, place_in_unet: str):
method reset (line 78) | def reset(self):
method __init__ (line 82) | def __init__(self):
class EmptyControl (line 87) | class EmptyControl(AttentionControl):
method forward (line 89) | def forward (self, attn, is_cross: bool, place_in_unet: str):
class AttentionStore (line 93) | class AttentionStore(AttentionControl):
method get_empty_store (line 96) | def get_empty_store():
method forward (line 100) | def forward(self, attn, is_cross: bool, place_in_unet: str):
method between_steps (line 106) | def between_steps(self):
method get_average_attention (line 115) | def get_average_attention(self):
method reset (line 120) | def reset(self):
method __init__ (line 125) | def __init__(self):
class AttentionControlEdit (line 131) | class AttentionControlEdit(AttentionStore, abc.ABC):
method step_callback (line 133) | def step_callback(self, x_t):
method replace_self_attention (line 138) | def replace_self_attention(self, attn_base, att_replace):
method replace_cross_attention (line 145) | def replace_cross_attention(self, attn_base, att_replace):
method forward (line 148) | def forward(self, attn, is_cross: bool, place_in_unet: str):
method __init__ (line 163) | def __init__(self, prompts, num_steps: int,
class AttentionReplace (line 177) | class AttentionReplace(AttentionControlEdit):
method replace_cross_attention (line 179) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 182) | def __init__(self, prompts, num_steps: int, cross_replace_steps: float...
class AttentionRefine (line 188) | class AttentionRefine(AttentionControlEdit):
method replace_cross_attention (line 190) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 195) | def __init__(self, prompts, num_steps: int, cross_replace_steps: float...
class AttentionReweight (line 203) | class AttentionReweight(AttentionControlEdit):
method replace_cross_attention (line 205) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 211) | def __init__(self, prompts, num_steps: int, cross_replace_steps: float...
function get_equalizer (line 218) | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], v...
function aggregate_attention (line 231) | def aggregate_attention(attention_store: AttentionStore, res: int, from_...
function show_cross_attention (line 245) | def show_cross_attention(attention_store: AttentionStore, res: int, from...
function show_self_attention_comp (line 261) | def show_self_attention_comp(attention_store: AttentionStore, res: int, ...
function run_and_display (line 276) | def run_and_display(model, prompts, controller, latent=None, run_baselin...
function load_512 (line 284) | def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
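
The attention-control classes above follow the Prompt-to-Prompt scheme: during editing, cross-attention maps computed for the source prompt are injected into the edited prompt's forward pass. The sketch below shows the re-indexing step behind `AttentionReplace.replace_cross_attention` with an identity mapper as a toy example; tensor sizes are illustrative.

```python
import torch

def replace_cross_attention(attn_base, mapper):
    """Re-index source-prompt attention (heads x pixels x src_tokens) onto the
    edited prompt's tokens through a token-to-token mapper matrix."""
    # attn_base: (heads, pixels, src_tokens); mapper: (batch, src_tokens, tgt_tokens)
    return torch.einsum("hps,bst->bhpt", attn_base, mapper)

attn_base = torch.rand(8, 256, 77)       # attention maps from the source prompt
mapper = torch.eye(77).unsqueeze(0)       # identity token mapping, for illustration
edited = replace_cross_attention(attn_base, mapper)
print(edited.shape)                       # torch.Size([1, 8, 256, 77])
```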
FILE: models/edit_friendly_ddm/ptp_utils.py
function text_under_image (line 30) | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int...
function view_images (line 43) | def view_images(images, num_rows=1, offset_ratio=0.02):
function diffusion_step (line 72) | def diffusion_step(model, controller, latents, context, t, guidance_scal...
function latent2image (line 87) | def latent2image(vae, latents):
function init_latent (line 96) | def init_latent(latent, model, height, width, generator, batch_size):
function text2image_ldm (line 107) | def text2image_ldm(
function text2image_ldm_stable (line 138) | def text2image_ldm_stable(
function register_attention_control (line 184) | def register_attention_control(model, controller):
function get_word_inds (line 254) | def get_word_inds(text: str, word_place: int, tokenizer):
function update_alpha_time_word (line 275) | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, floa...
function get_time_words_attention_alpha (line 288) | def get_time_words_attention_alpha(prompts, num_steps,
FILE: models/edit_friendly_ddm/seq_aligner.py
class ScoreParams (line 23) | class ScoreParams:
method __init__ (line 25) | def __init__(self, gap, match, mismatch):
method mis_match_char (line 30) | def mis_match_char(self, x, y):
function get_matrix (line 37) | def get_matrix(size_x, size_y, gap):
function get_matrix (line 51) | def get_matrix(size_x, size_y, gap):
function get_traceback_matrix (line 58) | def get_traceback_matrix(size_x, size_y):
function global_align (line 66) | def global_align(x, y, score):
function get_aligned_sequences (line 84) | def get_aligned_sequences(x, y, trace_back):
function get_mapper (line 112) | def get_mapper(x: str, y: str, tokenizer, max_len=77):
function get_refinement_mapper (line 126) | def get_refinement_mapper(prompts, tokenizer, max_len=77):
function get_word_inds (line 136) | def get_word_inds(text: str, word_place: int, tokenizer):
function get_replacement_mapper_ (line 157) | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
function get_replacement_mapper (line 194) | def get_replacement_mapper(prompts, tokenizer, max_len=77):
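
`seq_aligner.global_align` runs a Needleman-Wunsch style dynamic program so that tokens of the edited prompt can be mapped onto tokens of the source prompt. A compact standalone version of that alignment score (gap/match/mismatch values are illustrative, and the traceback step is omitted):

```python
import numpy as np

def global_align(x, y, gap=-1, match=1, mismatch=-1):
    """Minimal Needleman-Wunsch global alignment score between two token lists."""
    m = np.zeros((len(x) + 1, len(y) + 1))
    m[:, 0] = np.arange(len(x) + 1) * gap
    m[0, :] = np.arange(len(y) + 1) * gap
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            score = match if x[i - 1] == y[j - 1] else mismatch
            m[i, j] = max(m[i - 1, j] + gap, m[i, j - 1] + gap, m[i - 1, j - 1] + score)
    return m[-1, -1]

print(global_align("a cat on a sofa".split(), "a dog on a sofa".split()))  # 3.0
```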
FILE: models/instructpix2pix/dataset_creation/generate_img_dataset.py
function append_dims (line 29) | def append_dims(x, target_dims):
function to_d (line 37) | def to_d(x, sigma, denoised):
function get_ancestral_step (line 42) | def get_ancestral_step(sigma_from, sigma_to):
function sample_euler_ancestral (line 50) | def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0...
function load_model_from_config (line 73) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
class CFGDenoiser (line 97) | class CFGDenoiser(nn.Module):
method __init__ (line 98) | def __init__(self, model):
method forward (line 102) | def forward(self, x, sigma, uncond, cond, cfg_scale):
function to_pil (line 110) | def to_pil(image: torch.Tensor) -> Image.Image:
function main (line 116) | def main():
FILE: models/instructpix2pix/dataset_creation/generate_txt_dataset.py
function generate (line 20) | def generate(
function main (line 57) | def main(openai_model: str, num_samples: int, num_partitions: int, parti...
FILE: models/instructpix2pix/dataset_creation/prepare_dataset.py
function main (line 8) | def main():
FILE: models/instructpix2pix/dataset_creation/prepare_for_gpt.py
function main (line 7) | def main(input_path: str, output_path: str):
FILE: models/instructpix2pix/edit_app.py
class CFGDenoiser (line 60) | class CFGDenoiser(nn.Module):
method __init__ (line 61) | def __init__(self, model):
method forward (line 65) | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_sc...
function load_model_from_config (line 76) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
function main (line 100) | def main():
FILE: models/instructpix2pix/edit_cli.py
class CFGDenoiser (line 23) | class CFGDenoiser(nn.Module):
method __init__ (line 24) | def __init__(self, model):
method forward (line 28) | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_sc...
function load_model_from_config (line 39) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
function main (line 63) | def main():
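
The `CFGDenoiser.forward` listed for `edit_cli.py` takes both a text scale and an image scale: InstructPix2Pix guides with respect to the edit instruction and the input image separately. A sketch of that two-way combination (the scale values are just typical defaults, not taken from this file):

```python
import torch

def dual_cfg(out_cond, out_img_cond, out_uncond, text_scale=7.5, image_scale=1.5):
    """Two-way classifier-free guidance: separate scales for the text
    instruction and the image conditioning."""
    return (
        out_uncond
        + text_scale * (out_cond - out_img_cond)
        + image_scale * (out_img_cond - out_uncond)
    )

shape = (1, 4, 64, 64)
pred = dual_cfg(torch.randn(shape), torch.randn(shape), torch.randn(shape))
print(pred.shape)
```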
FILE: models/instructpix2pix/edit_dataset.py
class EditDataset (line 16) | class EditDataset(Dataset):
method __init__ (line 17) | def __init__(
method __len__ (line 48) | def __len__(self) -> int:
method __getitem__ (line 51) | def __getitem__(self, i: int) -> dict[str, Any]:
class EditDatasetEval (line 75) | class EditDatasetEval(Dataset):
method __init__ (line 76) | def __init__(
method __len__ (line 101) | def __len__(self) -> int:
method __getitem__ (line 104) | def __getitem__(self, i: int) -> dict[str, Any]:
FILE: models/instructpix2pix/main.py
function get_parser (line 30) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 130) | def nondefault_trainer_args(opt):
class WrappedDataset (line 137) | class WrappedDataset(Dataset):
method __init__ (line 140) | def __init__(self, dataset):
method __len__ (line 143) | def __len__(self):
method __getitem__ (line 146) | def __getitem__(self, idx):
function worker_init_fn (line 150) | def worker_init_fn(_):
class DataModuleFromConfig (line 166) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 167) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 189) | def prepare_data(self):
method setup (line 193) | def setup(self, stage=None):
method _train_dataloader (line 201) | def _train_dataloader(self):
method _val_dataloader (line 211) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 222) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 235) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 244) | class SetupCallback(Callback):
method __init__ (line 245) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 255) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_pretrain_routine_start (line 261) | def on_pretrain_routine_start(self, trainer, pl_module):
function get_world_size (line 281) | def get_world_size():
function all_gather (line 288) | def all_gather(data):
class ImageLogger (line 351) | class ImageLogger(Callback):
method __init__ (line 352) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _testtube (line 372) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 383) | def log_local(self, save_dir, split, images, prompts,
method log_img (line 414) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 450) | def check_frequency(self, check_idx):
method on_train_batch_end (line 458) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 462) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 470) | class CUDACallback(Callback):
method on_train_epoch_start (line 472) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 478) | def on_train_epoch_end(self, trainer, pl_module, outputs):
function melk (line 755) | def melk(*args, **kwargs):
function divein (line 763) | def divein(*args, **kwargs):
FILE: models/instructpix2pix/metrics/clip_similarity.py
class ClipSimilarity (line 10) | class ClipSimilarity(nn.Module):
method __init__ (line 11) | def __init__(self, name: str = "ViT-L/14"):
method encode_text (line 22) | def encode_text(self, text: list[str]) -> torch.Tensor:
method encode_image (line 28) | def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input ...
method forward (line 36) | def forward(
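
`ClipSimilarity.forward` reports, among other terms, a directional similarity: how well the change in CLIP image embeddings tracks the change in caption embeddings. A minimal sketch of that metric with random stand-in features (real usage would pass CLIP-encoded, normalized embeddings):

```python
import torch
import torch.nn.functional as F

def directional_similarity(img_feat_0, img_feat_1, txt_feat_0, txt_feat_1):
    """Cosine between the image-embedding change and the caption-embedding change."""
    return F.cosine_similarity(img_feat_1 - img_feat_0, txt_feat_1 - txt_feat_0)

f = lambda: F.normalize(torch.randn(2, 768), dim=-1)     # stand-in CLIP features
print(directional_similarity(f(), f(), f(), f()).shape)  # torch.Size([2])
```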
FILE: models/instructpix2pix/metrics/compute_metrics.py
class CFGDenoiser (line 34) | class CFGDenoiser(nn.Module):
method __init__ (line 35) | def __init__(self, model):
method forward (line 39) | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_sc...
function load_model_from_config (line 50) | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
class ImageEditor (line 73) | class ImageEditor(nn.Module):
method __init__ (line 74) | def __init__(self, config, ckpt, vae_ckpt=None):
method forward (line 84) | def forward(
function compute_metrics (line 117) | def compute_metrics(config,
function plot_metrics (line 186) | def plot_metrics(metrics_file, output_path):
function main (line 205) | def main():
FILE: models/instructpix2pix/prompt_app.py
function main (line 13) | def main(openai_model: str):
FILE: models/instructpix2pix/stable_diffusion/ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/autoencoder.py
class VQModel (line 14) | class VQModel(pl.LightningModule):
method __init__ (line 15) | def __init__(self,
method ema_scope (line 64) | def ema_scope(self, context=None):
method init_from_ckpt (line 78) | def init_from_ckpt(self, path, ignore_keys=list()):
method on_train_batch_end (line 92) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 96) | def encode(self, x):
method encode_to_prequant (line 102) | def encode_to_prequant(self, x):
method decode (line 107) | def decode(self, quant):
method decode_code (line 112) | def decode_code(self, code_b):
method forward (line 117) | def forward(self, input, return_pred_indices=False):
method get_input (line 124) | def get_input(self, batch, k):
method training_step (line 142) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 164) | def validation_step(self, batch, batch_idx):
method _validation_step (line 170) | def _validation_step(self, batch, batch_idx, suffix=""):
method configure_optimizers (line 197) | def configure_optimizers(self):
method get_last_layer (line 230) | def get_last_layer(self):
method log_images (line 233) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
method to_rgb (line 255) | def to_rgb(self, x):
class VQModelInterface (line 264) | class VQModelInterface(VQModel):
method __init__ (line 265) | def __init__(self, embed_dim, *args, **kwargs):
method encode (line 269) | def encode(self, x):
method decode (line 274) | def decode(self, h, force_not_quantize=False):
class AutoencoderKL (line 285) | class AutoencoderKL(pl.LightningModule):
method __init__ (line 286) | def __init__(self,
method init_from_ckpt (line 313) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 324) | def encode(self, x):
method decode (line 330) | def decode(self, z):
method forward (line 335) | def forward(self, input, sample_posterior=True):
method get_input (line 344) | def get_input(self, batch, k):
method training_step (line 351) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 372) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 386) | def configure_optimizers(self):
method get_last_layer (line 397) | def get_last_layer(self):
method log_images (line 401) | def log_images(self, batch, only_inputs=False, **kwargs):
method to_rgb (line 417) | def to_rgb(self, x):
class IdentityFirstStage (line 426) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 427) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 431) | def encode(self, x, *args, **kwargs):
method decode (line 434) | def decode(self, x, *args, **kwargs):
method quantize (line 437) | def quantize(self, x, *args, **kwargs):
method forward (line 442) | def forward(self, x, *args, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/classifier.py
function disabled_train (line 22) | def disabled_train(self, mode=True):
class NoisyLatentImageClassifier (line 28) | class NoisyLatentImageClassifier(pl.LightningModule):
method __init__ (line 30) | def __init__(self,
method init_from_ckpt (line 70) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method load_diffusion (line 88) | def load_diffusion(self):
method load_classifier (line 95) | def load_classifier(self, ckpt_path, pool):
method get_x_noisy (line 110) | def get_x_noisy(self, x, t, noise=None):
method forward (line 120) | def forward(self, x_noisy, t, *args, **kwargs):
method get_input (line 124) | def get_input(self, batch, k):
method get_conditioning (line 133) | def get_conditioning(self, batch, k=None):
method compute_top_k (line 150) | def compute_top_k(self, logits, labels, k, reduction="mean"):
method on_train_epoch_start (line 157) | def on_train_epoch_start(self):
method write_logs (line 162) | def write_logs(self, loss, logits, targets):
method shared_step (line 179) | def shared_step(self, batch, t=None):
method training_step (line 198) | def training_step(self, batch, batch_idx):
method reset_noise_accs (line 202) | def reset_noise_accs(self):
method on_validation_start (line 206) | def on_validation_start(self):
method validation_step (line 210) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 220) | def configure_optimizers(self):
method log_images (line 238) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/ddim.py
class DDIMSampler (line 12) | class DDIMSampler(object):
method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 19) | def register_buffer(self, name, attr):
method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 57) | def sample(self,
method ddim_sampling (line 114) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 166) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method stochastic_encode (line 207) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 223) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/ddpm.py
function disabled_train (line 34) | def disabled_train(self, mode=True):
function uniform_on_device (line 40) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 44) | class DDPM(pl.LightningModule):
method __init__ (line 46) | def __init__(self,
method register_schedule (line 117) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 172) | def ema_scope(self, context=None):
method init_from_ckpt (line 186) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 204) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 216) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 222) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 231) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 244) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 253) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 268) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 274) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 279) | def get_loss(self, pred, target, mean=True):
method p_losses (line 294) | def p_losses(self, x_start, t, noise=None):
method forward (line 323) | def forward(self, x, *args, **kwargs):
method get_input (line 329) | def get_input(self, batch, k):
method shared_step (line 337) | def shared_step(self, batch):
method training_step (line 342) | def training_step(self, batch, batch_idx):
method validation_step (line 358) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 366) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 370) | def _get_rows_from_list(self, samples):
method log_images (line 378) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 415) | def configure_optimizers(self):
class LatentDiffusion (line 424) | class LatentDiffusion(DDPM):
method __init__ (line 426) | def __init__(self,
method make_cond_schedule (line 471) | def make_cond_schedule(self, ):
method on_train_batch_start (line 478) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 493) | def register_schedule(self,
method instantiate_first_stage (line 502) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 509) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 530) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 542) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 551) | def get_learned_conditioning(self, c):
method meshgrid (line 564) | def meshgrid(self, h, w):
method delta_border (line 571) | def delta_border(self, h, w):
method get_weighting (line 585) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 601) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 654) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 706) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method differentiable_decode_first_stage (line 766) | def differentiable_decode_first_stage(self, z, predict_cids=False, for...
method encode_first_stage (line 826) | def encode_first_stage(self, x):
method shared_step (line 865) | def shared_step(self, batch, **kwargs):
method forward (line 870) | def forward(self, x, c, *args, **kwargs):
method _rescale_annotations (line 881) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 891) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 994) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 998) | def _prior_bpd(self, x_start):
method p_losses (line 1012) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1047) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1079) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1110) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1166) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1217) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1235) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 1251) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1361) | def configure_optimizers(self):
method to_rgb (line 1386) | def to_rgb(self, x):
class DiffusionWrapper (line 1395) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1396) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1402) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class Layout2ImgDiffusion (line 1424) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1426) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1430) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/ddpm_edit.py
function disabled_train (line 37) | def disabled_train(self, mode=True):
function uniform_on_device (line 43) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 47) | class DDPM(pl.LightningModule):
method __init__ (line 49) | def __init__(self,
method register_schedule (line 128) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 183) | def ema_scope(self, context=None):
method init_from_ckpt (line 197) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 236) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 248) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 254) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 263) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 276) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 285) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 300) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 306) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 311) | def get_loss(self, pred, target, mean=True):
method p_losses (line 326) | def p_losses(self, x_start, t, noise=None):
method forward (line 355) | def forward(self, x, *args, **kwargs):
method get_input (line 361) | def get_input(self, batch, k):
method shared_step (line 364) | def shared_step(self, batch):
method training_step (line 369) | def training_step(self, batch, batch_idx):
method validation_step (line 385) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 393) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 397) | def _get_rows_from_list(self, samples):
method log_images (line 405) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 442) | def configure_optimizers(self):
class LatentDiffusion (line 451) | class LatentDiffusion(DDPM):
method __init__ (line 453) | def __init__(self,
method make_cond_schedule (line 503) | def make_cond_schedule(self, ):
method on_train_batch_start (line 510) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 525) | def register_schedule(self,
method instantiate_first_stage (line 534) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 541) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 562) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 574) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 583) | def get_learned_conditioning(self, c):
method meshgrid (line 596) | def meshgrid(self, h, w):
method delta_border (line 603) | def delta_border(self, h, w):
method get_weighting (line 617) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 633) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 686) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 719) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method differentiable_decode_first_stage (line 779) | def differentiable_decode_first_stage(self, z, predict_cids=False, for...
method encode_first_stage (line 839) | def encode_first_stage(self, x):
method shared_step (line 878) | def shared_step(self, batch, **kwargs):
method forward (line 883) | def forward(self, x, c, *args, **kwargs):
method _rescale_annotations (line 894) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 904) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 1007) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 1011) | def _prior_bpd(self, x_start):
method p_losses (line 1025) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1060) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1092) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1123) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1179) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1230) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1248) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 1264) | def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1375) | def configure_optimizers(self):
method to_rgb (line 1400) | def to_rgb(self, x):
class DiffusionWrapper (line 1409) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1410) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1416) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class Layout2ImgDiffusion (line 1438) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1440) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1444) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py
class NoiseScheduleVP (line 6) | class NoiseScheduleVP:
method __init__ (line 7) | def __init__(
method marginal_log_mean_coeff (line 125) | def marginal_log_mean_coeff(self, t):
method marginal_alpha (line 138) | def marginal_alpha(self, t):
method marginal_std (line 144) | def marginal_std(self, t):
method marginal_lambda (line 150) | def marginal_lambda(self, t):
method inverse_lambda (line 158) | def inverse_lambda(self, lamb):
function model_wrapper (line 177) | def model_wrapper(
class DPM_Solver (line 351) | class DPM_Solver:
method __init__ (line 352) | def __init__(self, model_fn, noise_schedule, predict_x0=False, thresho...
method noise_prediction_fn (line 380) | def noise_prediction_fn(self, x, t):
method data_prediction_fn (line 386) | def data_prediction_fn(self, x, t):
method model_fn (line 401) | def model_fn(self, x, t):
method get_time_steps (line 410) | def get_time_steps(self, skip_type, t_T, t_0, N, device):
method get_orders_and_timesteps_for_singlestep_solver (line 439) | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order,...
method denoise_to_zero_fn (line 498) | def denoise_to_zero_fn(self, x, s):
method dpm_solver_first_update (line 504) | def dpm_solver_first_update(self, x, s, t, model_s=None, return_interm...
method singlestep_dpm_solver_second_update (line 551) | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s...
method singlestep_dpm_solver_third_update (line 633) | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./...
method multistep_dpm_solver_second_update (line 755) | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_pre...
method multistep_dpm_solver_third_update (line 812) | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev...
method singlestep_dpm_solver_update (line 859) | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermed...
method multistep_dpm_solver_update (line 885) | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list,...
method dpm_solver_adaptive (line 909) | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0....
method sample (line 965) | def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_...
function interpolate_fn (line 1132) | def interpolate_fn(x, xp, yp):
function expand_dims (line 1174) | def expand_dims(v, dims):
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/dpm_solver/sampler.py
class DPMSolverSampler (line 8) | class DPMSolverSampler(object):
method __init__ (line 9) | def __init__(self, model, **kwargs):
method register_buffer (line 15) | def register_buffer(self, name, attr):
method sample (line 22) | def sample(self,
FILE: models/instructpix2pix/stable_diffusion/ldm/models/diffusion/plms.py
class PLMSSampler (line 11) | class PLMSSampler(object):
method __init__ (line 12) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 58) | def sample(self,
method plms_sampling (line 115) | def plms_sampling(self, cond, shape,
method p_sample_plms (line 173) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/attention.py
function exists (line 14) | def exists(val):
function uniq (line 18) | def uniq(arr):
function default (line 22) | def default(val, d):
function max_neg_value (line 28) | def max_neg_value(t):
function init_ (line 32) | def init_(tensor):
class GEGLU (line 40) | class GEGLU(nn.Module):
method __init__ (line 41) | def __init__(self, dim_in, dim_out):
method forward (line 45) | def forward(self, x):
class FeedForward (line 50) | class FeedForward(nn.Module):
method __init__ (line 51) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 66) | def forward(self, x):
function zero_module (line 70) | def zero_module(module):
function Normalize (line 79) | def Normalize(in_channels):
class LinearAttention (line 83) | class LinearAttention(nn.Module):
method __init__ (line 84) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 91) | def forward(self, x):
class SpatialSelfAttention (line 102) | class SpatialSelfAttention(nn.Module):
method __init__ (line 103) | def __init__(self, in_channels):
method forward (line 129) | def forward(self, x):
class CrossAttention (line 155) | class CrossAttention(nn.Module):
method __init__ (line 156) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 175) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 210) | class BasicTransformerBlock(nn.Module):
method __init__ (line 211) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 222) | def forward(self, x, context=None):
method _forward (line 225) | def _forward(self, x, context=None):
class SpatialTransformer (line 232) | class SpatialTransformer(nn.Module):
method __init__ (line 240) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 264) | def forward(self, x, context=None):
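
The `CrossAttention` block above is where text conditioning enters the UNet: queries come from image features, keys and values from the prompt embeddings. A condensed, self-contained sketch of that block (assuming `einops` is available, as it is elsewhere in this file; dimensions in the usage line are illustrative):

```python
import torch
import torch.nn as nn
from einops import rearrange

class TinyCrossAttention(nn.Module):
    """Condensed sketch of multi-head cross-attention."""
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        context_dim = context_dim or query_dim
        self.heads, self.scale = heads, dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_k = nn.Linear(context_dim, inner, bias=False)
        self.to_v = nn.Linear(context_dim, inner, bias=False)
        self.to_out = nn.Linear(inner, query_dim)

    def forward(self, x, context=None):
        context = x if context is None else context
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=self.heads) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, "(b h) n d -> b n (h d)", h=self.heads)
        return self.to_out(out)

attn = TinyCrossAttention(query_dim=320, context_dim=768)
print(attn(torch.randn(1, 64, 320), torch.randn(1, 77, 768)).shape)  # (1, 64, 320)
```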
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 12) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 33) | def nonlinearity(x):
function Normalize (line 38) | def Normalize(in_channels, num_groups=32):
class Upsample (line 42) | class Upsample(nn.Module):
method __init__ (line 43) | def __init__(self, in_channels, with_conv):
method forward (line 53) | def forward(self, x):
class Downsample (line 60) | class Downsample(nn.Module):
method __init__ (line 61) | def __init__(self, in_channels, with_conv):
method forward (line 72) | def forward(self, x):
class ResnetBlock (line 82) | class ResnetBlock(nn.Module):
method __init__ (line 83) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 121) | def forward(self, x, temb):
class LinAttnBlock (line 144) | class LinAttnBlock(LinearAttention):
method __init__ (line 146) | def __init__(self, in_channels):
class AttnBlock (line 150) | class AttnBlock(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels):
method forward (line 178) | def forward(self, x):
function make_attn (line 205) | def make_attn(in_channels, attn_type="vanilla"):
class Model (line 216) | class Model(nn.Module):
method __init__ (line 217) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 316) | def forward(self, x, t=None, context=None):
method get_last_layer (line 364) | def get_last_layer(self):
class Encoder (line 368) | class Encoder(nn.Module):
method __init__ (line 369) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 434) | def forward(self, x):
class Decoder (line 462) | class Decoder(nn.Module):
method __init__ (line 463) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 535) | def forward(self, z):
class SimpleDecoder (line 571) | class SimpleDecoder(nn.Module):
method __init__ (line 572) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 594) | def forward(self, x):
class UpsampleDecoder (line 607) | class UpsampleDecoder(nn.Module):
method __init__ (line 608) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 641) | def forward(self, x):
class LatentRescaler (line 655) | class LatentRescaler(nn.Module):
method __init__ (line 656) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 680) | def forward(self, x):
class MergedRescaleEncoder (line 692) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 693) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 705) | def forward(self, x):
class MergedRescaleDecoder (line 711) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 712) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 722) | def forward(self, x):
class Upsampler (line 728) | class Upsampler(nn.Module):
method __init__ (line 729) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 741) | def forward(self, x):
class Resize (line 747) | class Resize(nn.Module):
method __init__ (line 748) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 763) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 770) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 772) | def __init__(self, ch_mult:list, in_channels,
method instantiate_pretrained (line 807) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 816) | def encode_with_pretrained(self,x):
method forward (line 822) | def forward(self,x):
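This file holds the VAE/UNet building blocks inherited from the CompVis latent-diffusion codebase. As a quick orientation, below is a minimal sketch of the small helpers at the top of the listing (`nonlinearity`, `Normalize`, `Upsample`), written under the assumption that they follow the upstream Stable Diffusion implementation; it is illustrative only, not a verbatim copy of this repository's file.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def nonlinearity(x):
    # swish / SiLU activation used throughout the autoencoder blocks
    return x * torch.sigmoid(x)

def Normalize(in_channels, num_groups=32):
    # channel-wise GroupNorm shared by ResnetBlock / AttnBlock
    return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels,
                        eps=1e-6, affine=True)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels,
                                  kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
```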
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 24) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 27) | def convert_module_to_f32(x):
class AttentionPool2d (line 32) | class AttentionPool2d(nn.Module):
method __init__ (line 37) | def __init__(
method forward (line 51) | def forward(self, x):
class TimestepBlock (line 62) | class TimestepBlock(nn.Module):
method forward (line 68) | def forward(self, x, emb):
class TimestepEmbedSequential (line 74) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 80) | def forward(self, x, emb, context=None):
class Upsample (line 91) | class Upsample(nn.Module):
method __init__ (line 100) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 109) | def forward(self, x):
class TransposedUpsample (line 121) | class TransposedUpsample(nn.Module):
method __init__ (line 123) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 130) | def forward(self,x):
class Downsample (line 134) | class Downsample(nn.Module):
method __init__ (line 143) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 158) | def forward(self, x):
class ResBlock (line 163) | class ResBlock(TimestepBlock):
method __init__ (line 179) | def __init__(
method forward (line 243) | def forward(self, x, emb):
method _forward (line 255) | def _forward(self, x, emb):
class AttentionBlock (line 278) | class AttentionBlock(nn.Module):
method __init__ (line 285) | def __init__(
method forward (line 314) | def forward(self, x):
method _forward (line 318) | def _forward(self, x):
function count_flops_attn (line 327) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 347) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 352) | def __init__(self, n_heads):
method forward (line 356) | def forward(self, qkv):
method count_flops (line 375) | def count_flops(model, _x, y):
class QKVAttention (line 379) | class QKVAttention(nn.Module):
method __init__ (line 384) | def __init__(self, n_heads):
method forward (line 388) | def forward(self, qkv):
method count_flops (line 409) | def count_flops(model, _x, y):
class UNetModel (line 413) | class UNetModel(nn.Module):
method __init__ (line 443) | def __init__(
method convert_to_fp16 (line 694) | def convert_to_fp16(self):
method convert_to_fp32 (line 702) | def convert_to_fp32(self):
method forward (line 710) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
class EncoderUNetModel (line 745) | class EncoderUNetModel(nn.Module):
method __init__ (line 751) | def __init__(
method convert_to_fp16 (line 924) | def convert_to_fp16(self):
method convert_to_fp32 (line 931) | def convert_to_fp32(self):
method forward (line 938) | def forward(self, x, timesteps):
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 131) | def backward(ctx, *output_grads):
function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 174) | def zero_module(module):
function scale_module (line 183) | def scale_module(module, scale):
function mean_flat (line 192) | def mean_flat(tensor):
function normalization (line 199) | def normalization(channels):
class SiLU (line 209) | class SiLU(nn.Module):
method forward (line 210) | def forward(self, x):
class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
method forward (line 215) | def forward(self, x):
function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
function linear (line 231) | def linear(*args, **kwargs):
function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 258) | def forward(self, c_concat, c_crossattn):
function noise_like (line 264) | def noise_like(shape, device, repeat=False):
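`timestep_embedding` in this file is the standard sinusoidal embedding from guided-diffusion, which maps integer timesteps to the vectors consumed by the UNet's time MLP. A minimal sketch of the usual formulation (the `repeat_only` branch is omitted), assuming this file matches the upstream code:

```python
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    # timesteps: 1-D tensor of N indices; returns an (N, dim) tensor of
    # sinusoidal embeddings, as in the original guided-diffusion code.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
```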
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
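`DiagonalGaussianDistribution` wraps the KL autoencoder's encoder output. A hedged sketch of its core behaviour (mean/log-variance split, reparameterised sampling, KL against a standard normal), assuming it mirrors the upstream latent-diffusion version:

```python
import torch

class DiagonalGaussianDistribution:
    def __init__(self, parameters, deterministic=False):
        # parameters: (B, 2*C, H, W) tensor holding mean and log-variance
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if deterministic:
            self.std = self.var = torch.zeros_like(self.mean)

    def sample(self):
        # reparameterisation trick: mean + std * eps
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self, other=None):
        # KL divergence to a standard normal when `other` is None
        if self.deterministic:
            return torch.tensor([0.0])
        return 0.5 * torch.sum(
            torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]
        )

    def mode(self):
        return self.mean
```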
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
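`LitEma` maintains an exponential moving average of the model weights. The update it applies on each call is essentially the following standalone sketch (the real class additionally tracks `num_updates` to warm up the decay); this is an assumption based on the standard EMA formulation, not a copy of the file:

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.9999):
    # shadow <- decay * shadow + (1 - decay) * param, for each parameter pair
    for shadow, param in zip(ema_params, model_params):
        shadow.mul_(decay).add_(param, alpha=1.0 - decay)
```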
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/encoders/modules.py
class AbstractEncoder (line 12) | class AbstractEncoder(nn.Module):
method __init__ (line 13) | def __init__(self):
method encode (line 16) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 21) | class ClassEmbedder(nn.Module):
method __init__ (line 22) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 27) | def forward(self, batch, key=None):
class TransformerEmbedder (line 36) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 38) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 44) | def forward(self, tokens):
method encode (line 49) | def encode(self, x):
class BERTTokenizer (line 53) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 55) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 63) | def forward(self, text):
method encode (line 70) | def encode(self, text):
method decode (line 76) | def decode(self, text):
class BERTEmbedder (line 80) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 82) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 93) | def forward(self, text):
method encode (line 101) | def encode(self, text):
class SpatialRescaler (line 106) | class SpatialRescaler(nn.Module):
method __init__ (line 107) | def __init__(self,
method forward (line 125) | def forward(self,x):
method encode (line 134) | def encode(self, x):
class FrozenCLIPEmbedder (line 137) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 139) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 147) | def freeze(self):
method forward (line 152) | def forward(self, text):
method encode (line 161) | def encode(self, text):
class FrozenCLIPTextEmbedder (line 165) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 169) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 177) | def freeze(self):
method forward (line 182) | def forward(self, text):
method encode (line 189) | def encode(self, text):
class FrozenClipImageEmbedder (line 197) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 201) | def __init__(
method preprocess (line 216) | def preprocess(self, x):
method forward (line 226) | def forward(self, x):
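`FrozenCLIPEmbedder` is the text conditioner used by Stable Diffusion v1. A hedged sketch of its usual shape (a frozen Hugging Face `transformers` CLIP text model returning the last hidden states), under the assumption it mirrors the upstream class:

```python
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel

class FrozenCLIPEmbedder(nn.Module):
    def __init__(self, version="openai/clip-vit-large-patch14",
                 device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        # the text encoder is never trained; it only provides conditioning
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        batch = self.tokenizer(text, truncation=True, max_length=self.max_length,
                               padding="max_length", return_tensors="pt")
        tokens = batch["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        return outputs.last_hidden_state  # (B, 77, 768) conditioning sequence

    def encode(self, text):
        return self(text)
```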
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 343) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 373) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 390) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 408) | def add_Poisson_noise(img):
function add_JPEG_noise (line 422) | def add_JPEG_noise(img):
function random_crop (line 431) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 442) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 534) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
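Among the many conversion helpers in this file, `calculate_psnr` is the standard PSNR between two uint8 images. A minimal sketch assuming the conventional definition (MSE over the optionally cropped region, peak value 255):

```python
import math
import numpy as np

def calculate_psnr(img1, img2, border=0):
    # img1, img2: HxWxC uint8 arrays in [0, 255]; `border` pixels are cropped
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border].astype(np.float64)
    img2 = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(255.0 / math.sqrt(mse))
```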
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/losses/contperceptual.py
class LPIPSWithDiscriminator (line 7) | class LPIPSWithDiscriminator(nn.Module):
method __init__ (line 8) | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixello...
method calculate_adaptive_weight (line 32) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 45) | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
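`adopt_weight` and the hinge discriminator loss are the usual taming-transformers ingredients for delaying the adversarial term until `disc_start`. A hedged sketch, assuming this file follows that convention:

```python
import torch
import torch.nn.functional as F

def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # disable (or rescale) the GAN loss until training reaches `threshold` steps
    return value if global_step < threshold else weight

def hinge_d_loss(logits_real, logits_fake):
    # standard hinge loss for the discriminator
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)
```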
FILE: models/instructpix2pix/stable_diffusion/ldm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 25) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 26) | def __init__(self, dim, max_seq_len):
method init_ (line 31) | def init_(self):
method forward (line 34) | def forward(self, x):
class FixedPositionalEmbedding (line 39) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 40) | def __init__(self, dim):
method forward (line 45) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 54) | def exists(val):
function default (line 58) | def default(val, d):
function always (line 64) | def always(val):
function not_equals (line 70) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 82) | def max_neg_value(tensor):
function pick_and_pop (line 88) | def pick_and_pop(keys, d):
function group_dict_by_key (line 93) | def group_dict_by_key(cond, d):
function string_begins_with (line 102) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 106) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 110) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 117) | class Scale(nn.Module):
method __init__ (line 118) | def __init__(self, value, fn):
method forward (line 123) | def forward(self, x, **kwargs):
class Rezero (line 128) | class Rezero(nn.Module):
method __init__ (line 129) | def __init__(self, fn):
method forward (line 134) | def forward(self, x, **kwargs):
class ScaleNorm (line 139) | class ScaleNorm(nn.Module):
method __init__ (line 140) | def __init__(self, dim, eps=1e-5):
method forward (line 146) | def forward(self, x):
class RMSNorm (line 151) | class RMSNorm(nn.Module):
method __init__ (line 152) | def __init__(self, dim, eps=1e-8):
method forward (line 158) | def forward(self, x):
class Residual (line 163) | class Residual(nn.Module):
method forward (line 164) | def forward(self, x, residual):
class GRUGating (line 168) | class GRUGating(nn.Module):
method __init__ (line 169) | def __init__(self, dim):
method forward (line 173) | def forward(self, x, residual):
class GEGLU (line 184) | class GEGLU(nn.Module):
method __init__ (line 185) | def __init__(self, dim_in, dim_out):
method forward (line 189) | def forward(self, x):
class FeedForward (line 194) | class FeedForward(nn.Module):
method __init__ (line 195) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 210) | def forward(self, x):
class Attention (line 215) | class Attention(nn.Module):
method __init__ (line 216) | def __init__(
method forward (line 268) | def forward(
class AttentionLayers (line 370) | class AttentionLayers(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 481) | def forward(
class Encoder (line 541) | class Encoder(AttentionLayers):
method __init__ (line 542) | def __init__(self, **kwargs):
class TransformerWrapper (line 548) | class TransformerWrapper(nn.Module):
method __init__ (line 549) | def __init__(
method init_ (line 595) | def init_(self):
method forward (line 598) | def forward(
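Two of the small modules listed above, `GEGLU` and `RMSNorm`, are common transformer building blocks. A sketch assuming the usual x-transformers formulation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    # gated GELU: project to 2x width, gate one half with GELU of the other
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

class RMSNorm(nn.Module):
    # RMS normalisation with a learned per-channel gain
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
```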
FILE: models/instructpix2pix/stable_diffusion/ldm/util.py
function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 41) | def ismap(x):
function isimage (line 47) | def isimage(x):
function exists (line 53) | def exists(x):
function default (line 57) | def default(val, d):
function mean_flat (line 63) | def mean_flat(tensor):
function count_params (line 71) | def count_params(model, verbose=False):
function instantiate_from_config (line 78) | def instantiate_from_config(config):
function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
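`instantiate_from_config` and `get_obj_from_str` implement the config-driven object construction used throughout the `ldm` codebase: a YAML node names a `target` import path plus optional `params`. A minimal sketch of the pattern:

```python
import importlib

def get_obj_from_str(string, reload=False):
    # "package.module.ClassName" -> the class object
    module, cls = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config: dict-like with a "target" import path and optional "params"
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
```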
FILE: models/instructpix2pix/stable_diffusion/main.py
function get_parser (line 24) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 126) | def nondefault_trainer_args(opt):
class WrappedDataset (line 133) | class WrappedDataset(Dataset):
method __init__ (line 136) | def __init__(self, dataset):
method __len__ (line 139) | def __len__(self):
method __getitem__ (line 142) | def __getitem__(self, idx):
function worker_init_fn (line 146) | def worker_init_fn(_):
class DataModuleFromConfig (line 162) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 163) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 185) | def prepare_data(self):
method setup (line 189) | def setup(self, stage=None):
method _train_dataloader (line 197) | def _train_dataloader(self):
method _val_dataloader (line 207) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 218) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 231) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 240) | class SetupCallback(Callback):
method __init__ (line 241) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 251) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_pretrain_routine_start (line 257) | def on_pretrain_routine_start(self, trainer, pl_module):
class ImageLogger (line 289) | class ImageLogger(Callback):
method __init__ (line 290) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _testtube (line 310) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 321) | def log_local(self, save_dir, split, images,
method log_img (line 340) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 372) | def check_frequency(self, check_idx):
method on_train_batch_end (line 383) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 387) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 395) | class CUDACallback(Callback):
method on_train_epoch_start (line 397) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 403) | def on_train_epoch_end(self, trainer, pl_module, outputs):
function melk (line 697) | def melk(*args, **kwargs):
function divein (line 705) | def divein(*args, **kwargs):
FILE: models/instructpix2pix/stable_diffusion/notebook_helpers.py
function download_models (line 19) | def download_models(mode):
function load_model_from_config (line 40) | def load_model_from_config(config, ckpt):
function get_model (line 52) | def get_model(mode):
function get_custom_cond (line 59) | def get_custom_cond(mode):
function get_cond_options (line 85) | def get_cond_options(mode):
function select_cond_path (line 92) | def select_cond_path(mode):
function get_cond (line 107) | def get_cond(mode, selected_path):
function visualize_cond_img (line 127) | def visualize_cond_img(path):
function run (line 131) | def run(model, selected_path, task, custom_steps, resize_enabled=False, ...
function convsample_ddim (line 188) | def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, n...
function make_convolutional_sample (line 208) | def make_convolutional_sample(batch, model, mode="vanilla", custom_steps...
FILE: models/instructpix2pix/stable_diffusion/scripts/img2img.py
function chunk (line 23) | def chunk(it, size):
function load_model_from_config (line 28) | def load_model_from_config(config, ckpt, verbose=False):
function load_img (line 48) | def load_img(path):
function main (line 60) | def main():
FILE: models/instructpix2pix/stable_diffusion/scripts/inpaint.py
function make_batch (line 11) | def make_batch(image, mask, device):
FILE: models/instructpix2pix/stable_diffusion/scripts/knn2img.py
function chunk (line 36) | def chunk(it, size):
function load_model_from_config (line 41) | def load_model_from_config(config, ckpt, verbose=False):
class Searcher (line 61) | class Searcher(object):
method __init__ (line 62) | def __init__(self, database, retriever_version='ViT-L/14'):
method train_searcher (line 75) | def train_searcher(self, k,
method load_single_file (line 91) | def load_single_file(self, saved_embeddings):
method load_multi_files (line 96) | def load_multi_files(self, data_archive):
method load_database (line 104) | def load_database(self):
method load_retriever (line 123) | def load_retriever(self, version='ViT-L/14', ):
method load_searcher (line 130) | def load_searcher(self):
method search (line 135) | def search(self, x, k):
method __call__ (line 163) | def __call__(self, x, n):
FILE: models/instructpix2pix/stable_diffusion/scripts/sample_diffusion.py
function custom_to_pil (line 15) | def custom_to_pil(x):
function custom_to_np (line 27) | def custom_to_np(x):
function logs2pil (line 36) | def logs2pil(logs, keys=["sample"]):
function convsample (line 54) | def convsample(model, shape, return_intermediates=True,
function convsample_ddim (line 69) | def convsample_ddim(model, steps, shape, eta=1.0
function make_convolutional_sample (line 79) | def make_convolutional_sample(model, batch_size, vanilla=False, custom_s...
function run (line 108) | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, ...
function save_logs (line 143) | def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
function get_parser (line 162) | def get_parser():
function load_model_from_config (line 220) | def load_model_from_config(config, sd):
function load_model (line 228) | def load_model(config, ckpt, gpu, eval_mode):
FILE: models/instructpix2pix/stable_diffusion/scripts/tests/test_watermark.py
function testit (line 6) | def testit(img_path):
FILE: models/instructpix2pix/stable_diffusion/scripts/train_searcher.py
function search_bruteforce (line 12) | def search_bruteforce(searcher):
function search_partioned_ah (line 16) | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
function search_ah (line 24) | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
function load_datapool (line 28) | def load_datapool(dpath):
function train_searcher (line 62) | def train_searcher(opt,
FILE: models/instructpix2pix/stable_diffusion/scripts/txt2img.py
function chunk (line 32) | def chunk(it, size):
function numpy_to_pil (line 37) | def numpy_to_pil(images):
function load_model_from_config (line 49) | def load_model_from_config(config, ckpt, verbose=False):
function put_watermark (line 69) | def put_watermark(img, wm_encoder=None):
function load_replacement (line 77) | def load_replacement(x):
function check_safety (line 88) | def check_safety(x_image):
function main (line 98) | def main():
FILE: models/masactrl/diffuser_utils.py
class MasaCtrlPipeline (line 14) | class MasaCtrlPipeline(StableDiffusionPipeline):
method next_step (line 16) | def next_step(
method step (line 39) | def step(
method image2latent (line 60) | def image2latent(self, image):
method latent2image (line 72) | def latent2image(self, latents, return_type='np'):
method latent2image_grad (line 84) | def latent2image_grad(self, latents):
method __call__ (line 91) | def __call__(
method invert (line 196) | def invert(
FILE: models/masactrl/masactrl.py
class MutualSelfAttentionControl (line 14) | class MutualSelfAttentionControl(AttentionBase):
method __init__ (line 20) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method attn_batch (line 41) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method forward (line 56) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlUnion (line 75) | class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
method __init__ (line 76) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method forward (line 89) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlMask (line 114) | class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
method __init__ (line 115) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step...
method attn_batch (line 138) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method forward (line 163) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
class MutualSelfAttentionControlMaskAuto (line 196) | class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
method __init__ (line 197) | def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_...
method after_step (line 227) | def after_step(self):
method attn_batch (line 231) | def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_...
method aggregate_cross_attn_map (line 260) | def aggregate_cross_attn_map(self, idx):
method forward (line 273) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
FILE: models/masactrl/masactrl_utils.py
class AttentionBase (line 14) | class AttentionBase:
method __init__ (line 15) | def __init__(self):
method after_step (line 20) | def after_step(self):
method __call__ (line 23) | def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_he...
method forward (line 33) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
method reset (line 38) | def reset(self):
class AttentionStore (line 43) | class AttentionStore(AttentionBase):
method __init__ (line 44) | def __init__(self, res=[32], min_step=0, max_step=1000):
method after_step (line 57) | def after_step(self):
method forward (line 70) | def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_hea...
function regiter_attention_editor_diffusers (line 79) | def regiter_attention_editor_diffusers(model, editor: AttentionBase):
function regiter_attention_editor_ldm (line 147) | def regiter_attention_editor_ldm(model, editor: AttentionBase):
FILE: models/p2p/attention_control.py
function register_attention_control (line 12) | def register_attention_control(model, controller):
function get_equalizer (line 84) | def get_equalizer(text, word_select, values, tokenizer=None):
class LocalBlend (line 95) | class LocalBlend:
method get_mask (line 97) | def get_mask(self, maps, alpha, use_pool):
method __call__ (line 108) | def __call__(self, x_t, attention_store):
method __init__ (line 123) | def __init__(self, prompts, words, substruct_words=None, start_blend=0...
class EmptyControl (line 150) | class EmptyControl:
method step_callback (line 152) | def step_callback(self, x_t):
method between_steps (line 155) | def between_steps(self):
method __call__ (line 158) | def __call__(self, attn, is_cross, place_in_unet):
class AttentionControl (line 162) | class AttentionControl(abc.ABC):
method step_callback (line 164) | def step_callback(self, x_t):
method between_steps (line 167) | def between_steps(self):
method num_uncond_att_layers (line 171) | def num_uncond_att_layers(self):
method forward (line 175) | def forward (self, attn, is_cross, place_in_unet):
method __call__ (line 178) | def __call__(self, attn, is_cross, place_in_unet):
method reset (line 192) | def reset(self):
method __init__ (line 196) | def __init__(self):
class SpatialReplace (line 201) | class SpatialReplace(EmptyControl):
method step_callback (line 203) | def step_callback(self, x_t):
method __init__ (line 209) | def __init__(self, stop_inject,num_ddim_steps=50):
class AttentionStore (line 214) | class AttentionStore(AttentionControl):
method get_empty_store (line 217) | def get_empty_store():
method forward (line 221) | def forward(self, attn, is_cross, place_in_unet):
method between_steps (line 227) | def between_steps(self):
method get_average_attention (line 236) | def get_average_attention(self):
method reset (line 240) | def reset(self):
method __init__ (line 245) | def __init__(self):
class AttentionControlEdit (line 251) | class AttentionControlEdit(AttentionStore, abc.ABC):
method step_callback (line 253) | def step_callback(self, x_t):
method replace_self_attention (line 258) | def replace_self_attention(self, attn_base, att_replace, place_in_unet):
method replace_cross_attention (line 266) | def replace_cross_attention(self, attn_base, att_replace):
method forward (line 269) | def forward(self, attn, is_cross, place_in_unet):
method __init__ (line 284) | def __init__(self,
class AttentionReplace (line 301) | class AttentionReplace(AttentionControlEdit):
method replace_cross_attention (line 303) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 306) | def __init__(self, prompts, num_steps, cross_replace_steps, self_repla...
class AttentionRefine (line 317) | class AttentionRefine(AttentionControlEdit):
method replace_cross_attention (line 319) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 325) | def __init__(self, prompts, num_steps, cross_replace_steps, self_repla...
class AttentionReweight (line 338) | class AttentionReweight(AttentionControlEdit):
method replace_cross_attention (line 340) | def replace_cross_attention(self, attn_base, att_replace):
method __init__ (line 347) | def __init__(self,
function make_controller (line 366) | def make_controller(pipeline,
FILE: models/p2p/inversion.py
class NegativePromptInversion (line 10) | class NegativePromptInversion:
method prev_step (line 12) | def prev_step(self, model_output, timestep, sample):
method next_step (line 22) | def next_step(self, model_output, timestep, sample):
method get_noise_pred_single (line 32) | def get_noise_pred_single(self, latents, t, context):
method init_prompt (line 37) | def init_prompt(self, prompt):
method ddim_loop (line 55) | def ddim_loop(self, latent):
method scheduler (line 68) | def scheduler(self):
method ddim_inversion (line 72) | def ddim_inversion(self, image):
method invert (line 78) | def invert(self, image_gt, prompt, npi_interp=0.0):
method __init__ (line 103) | def __init__(self, model,num_ddim_steps):
class NullInversion (line 113) | class NullInversion:
method prev_step (line 115) | def prev_step(self, model_output, timestep: int, sample):
method next_step (line 125) | def next_step(self, model_output, timestep: int, sample):
method get_noise_pred_single (line 135) | def get_noise_pred_single(self, latents, t, context):
method get_noise_pred (line 139) | def get_noise_pred(self, latents, t, guidance_scale, is_forward=True, ...
method init_prompt (line 154) | def init_prompt(self, prompt: str):
method ddim_loop (line 174) | def ddim_loop(self, latent):
method scheduler (line 186) | def scheduler(self):
method ddim_inversion (line 190) | def ddim_inversion(self, image):
method null_optimization (line 196) | def null_optimization(self, latents, num_inner_steps, epsilon, guidanc...
method invert (line 227) | def invert(self, image_gt, prompt, guidance_scale, num_inner_steps=10,...
method __init__ (line 236) | def __init__(self, model,num_ddim_steps):
class DirectInversion (line 245) | class DirectInversion:
method prev_step (line 247) | def prev_step(self, model_output, timestep: int, sample):
method next_step (line 262) | def next_step(self, model_output, timestep: int, sample):
method get_noise_pred_single (line 272) | def get_noise_pred_single(self, latents, t, context):
method get_noise_pred (line 276) | def get_noise_pred(self, latents, t, guidance_scale, is_forward=True, ...
method init_prompt (line 291) | def init_prompt(self, prompt: str):
method ddim_loop (line 309) | def ddim_loop(self, latent):
method ddim_null_loop (line 322) | def ddim_null_loop(self, latent):
method ddim_with_guidance_scale_loop (line 335) | def ddim_with_guidance_scale_loop(self, latent,guidance_scale):
method scheduler (line 351) | def scheduler(self):
method ddim_inversion (line 355) | def ddim_inversion(self, image):
method ddim_null_inversion (line 362) | def ddim_null_inversion(self, image):
method ddim_with_guidance_scale_inversion (line 369) | def ddim_with_guidance_scale_inversion(self, image,guidance_scale):
method offset_calculate (line 375) | def offset_calculate(self, latents, num_inner_steps, epsilon, guidance...
method invert (line 393) | def invert(self, image_gt, prompt, guidance_scale, num_inner_steps=10,...
method invert_without_attn_controller (line 402) | def invert_without_attn_controller(self, image_gt, prompt, guidance_sc...
method invert_with_guidance_scale_vary_guidance (line 410) | def invert_with_guidance_scale_vary_guidance(self, image_gt, prompt, i...
method null_latent_calculate (line 419) | def null_latent_calculate(self, latents, num_inner_steps, epsilon, gui...
method invert_null_latent (line 463) | def invert_null_latent(self, image_gt, prompt, guidance_scale, num_inn...
method offset_calculate_not_full (line 472) | def offset_calculate_not_full(self, latents, num_inner_steps, epsilon,...
method invert_not_full (line 491) | def invert_not_full(self, image_gt, prompt, guidance_scale, num_inner_...
method offset_calculate_skip_step (line 500) | def offset_calculate_skip_step(self, latents, num_inner_steps, epsilon...
method invert_skip_step (line 522) | def invert_skip_step(self, image_gt, prompt, guidance_scale, skip_step...
method __init__ (line 532) | def __init__(self, model,num_ddim_steps):
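The `next_step` method shared by `NegativePromptInversion`, `NullInversion`, and `DirectInversion` is a DDIM inversion step: given the noise predicted at timestep t, it moves the latent one step towards higher noise. A hedged, standalone sketch of that update (assuming the usual DDIM parameterisation with `alphas_cumprod` taken from the scheduler):

```python
import torch

def ddim_inversion_step(eps, alpha_prod_t, alpha_prod_t_next, x_t):
    # eps:                noise predicted by the UNet at the current timestep
    # alpha_prod_t:       cumulative alpha at the current (less noisy) timestep
    # alpha_prod_t_next:  cumulative alpha at the next (noisier) timestep
    # x_t:                current latent
    pred_x0 = (x_t - (1 - alpha_prod_t) ** 0.5 * eps) / alpha_prod_t ** 0.5
    direction = (1 - alpha_prod_t_next) ** 0.5 * eps
    return alpha_prod_t_next ** 0.5 * pred_x0 + direction
```

Iterating this update over all DDIM timesteps, starting from the encoded image latent, yields the inverted noise trajectory that the editing branches later denoise from.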
FILE: models/p2p/p2p_guidance_forward.py
function p2p_guidance_diffusion_step (line 6) | def p2p_guidance_diffusion_step(model, controller, latents, context, t, ...
function p2p_guidance_forward (line 22) | def p2p_guidance_forward(
function p2p_guidance_forward_single_branch (line 65) | def p2p_guidance_forward_single_branch(
function direct_inversion_p2p_guidance_diffusion_step (line 103) | def direct_inversion_p2p_guidance_diffusion_step(model, controller, late...
function direct_inversion_p2p_guidance_diffusion_step_add_target (line 119) | def direct_inversion_p2p_guidance_diffusion_step_add_target(model, contr...
function direct_inversion_p2p_guidance_forward (line 136) | def direct_inversion_p2p_guidance_forward(
function direct_inversion_p2p_guidance_forward_add_target (line 176) | def direct_inversion_p2p_guidance_forward_add_target(
FILE: models/p2p/proximal_guidance_forward.py
function dilate (line 7) | def dilate(image, kernel_size, stride=1, padding=0):
function proximal_guidance_diffusion_step (line 19) | def proximal_guidance_diffusion_step(model, controller, latents, context...
function proximal_guidance_forward (line 86) | def proximal_guidance_forward(
FILE: models/p2p/scheduler_dev.py
class DDIMSchedulerDev (line 8) | class DDIMSchedulerDev(DDIMScheduler):
method step (line 10) | def step(
FILE: models/p2p/seq_aligner.py
class ScoreParams (line 18) | class ScoreParams:
method __init__ (line 20) | def __init__(self, gap, match, mismatch):
method mis_match_char (line 25) | def mis_match_char(self, x, y):
function get_matrix (line 32) | def get_matrix(size_x, size_y, gap):
function get_matrix (line 46) | def get_matrix(size_x, size_y, gap):
function get_traceback_matrix (line 53) | def get_traceback_matrix(size_x, size_y):
function global_align (line 61) | def global_align(x, y, score):
function get_aligned_sequences (line 79) | def get_aligned_sequences(x, y, trace_back):
function get_mapper (line 107) | def get_mapper(x, y, tokenizer, max_len=77):
function get_refinement_mapper (line 121) | def get_refinement_mapper(prompts, tokenizer, max_len=77):
function get_word_inds (line 131) | def get_word_inds(text, word_place, tokenizer):
function get_replacement_mapper_ (line 152) | def get_replacement_mapper_(x, y, tokenizer, max_len=77):
function get_replacement_mapper (line 189) | def get_replacement_mapper(prompts, tokenizer, max_len=77):
FILE: models/p2p_editor.py
class P2PEditor (line 12) | class P2PEditor:
method __init__ (line 13) | def __init__(self, method_list, device, num_ddim_steps=50) -> None:
method __call__ (line 28) | def __call__(self,
method edit_image_ddim (line 137) | def edit_image_ddim(
method edit_image_null_text_inversion (line 199) | def edit_image_null_text_inversion(
method edit_image_null_text_inversion_single_branch (line 261) | def edit_image_null_text_inversion_single_branch(
method edit_image_negative_prompt_inversion (line 324) | def edit_image_negative_prompt_inversion(
method edit_image_directinversion (line 415) | def edit_image_directinversion(
method edit_image_directinversion_vary_guidance_scale (line 481) | def edit_image_directinversion_vary_guidance_scale(
method edit_image_null_text_inversion_proximal_guidanca (line 550) | def edit_image_null_text_inversion_proximal_guidanca(
method edit_image_null_latent_inversion (line 640) | def edit_image_null_latent_inversion(
method edit_image_directinversion_not_full (line 707) | def edit_image_directinversion_not_full(
method edit_image_directinversion_skip_step (line 775) | def edit_image_directinversion_skip_step(
method edit_image_directinversion_add_target (line 842) | def edit_image_directinversion_add_target(
method edit_image_directinversion_add_source (line 909) | def edit_image_directinversion_add_source(
FILE: models/pix2pix_zero/base_pipeline.py
class BasePipeline (line 17) | class BasePipeline(DiffusionPipeline):
method __init__ (line 19) | def __init__(
method _execution_device (line 110) | def _execution_device(self):
method _encode_prompt (line 129) | def _encode_prompt(
method decode_latents (line 269) | def decode_latents(self, latents):
method prepare_latents (line 277) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
method prepare_extra_step_kwargs (line 295) | def prepare_extra_step_kwargs(self, generator, eta):
method run_safety_checker (line 313) | def run_safety_checker(self, image, device, dtype):
FILE: models/pix2pix_zero/cross_attention.py
class MyCrossAttnProcessor (line 4) | class MyCrossAttnProcessor:
method __call__ (line 5) | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden...
function prep_unet (line 45) | def prep_unet(unet):
FILE: models/pix2pix_zero/ddim_inv.py
class DDIMInversion (line 21) | class DDIMInversion(BasePipeline):
method auto_corr_loss (line 23) | def auto_corr_loss(self, x, random_shift=True):
method kl_divergence (line 41) | def kl_divergence(self, x):
method __call__
Condensed preview: 339 files, each showing path, character count, and a content snippet (10,250K chars of structured content in total).
[
{
"path": ".gitignore",
"chars": 50,
"preview": "data\n__pycache__\n.vscode\noutput\n*.csv\n*.out\n*.bash"
},
{
"path": "README.md",
"chars": 18962,
"preview": "# PnPInversion\n\n\nThis repository contains the implementation of the ICLR2024 paper \"PnP Inversion: Boosting Diffusion-ba"
},
{
"path": "environment/edict_requirements.txt",
"chars": 66,
"preview": "diffusers==0.6.0\ntransformers==4.19.2\nmatplotlib\nomegaconf\nimageio"
},
{
"path": "environment/instructdiffusion_requirements.txt",
"chars": 147,
"preview": "einops==0.6.1\ntaming-transformers-rom1504==0.0.6\nomegaconf==2.3.0\nk-diffusion==0.0.16\ndeepspeed==0.10.2\ntimm==0.9.7\ntran"
},
{
"path": "environment/masactrl_requirements.txt",
"chars": 90,
"preview": "diffusers==0.15.0\ntransformers\nopencv-python\neinops\nomegaconf\npytorch_lightning\nmatplotlib"
},
{
"path": "environment/p2p_requirements.txt",
"chars": 82,
"preview": "diffusers==0.10.0\ntransformers\nftfy\nopencv-python\nipywidgets\nmatplotlib\naccelerate"
},
{
"path": "environment/pix2pix_zero_requirements.txt",
"chars": 46,
"preview": "diffusers==0.14.0\nmatplotlib\nsalesforce-lavis\n"
},
{
"path": "environment/pnp_requirements.txt",
"chars": 102,
"preview": "diffusers==0.17.1\nxformers==0.0.20\ntransformers==4.30.2\naccelerate==0.20.3\nmatplotlib\nsalesforce-lavis"
},
{
"path": "evaluation/evaluate.py",
"chars": 16358,
"preview": "import json\nimport argparse\nimport os\nimport numpy as np\nfrom PIL import Image\nimport csv\nfrom evaluation.matrics_calcul"
},
{
"path": "evaluation/matrics_calculator.py",
"chars": 17583,
"preview": "import torch\nfrom torchvision.transforms import Resize\nfrom torchvision import transforms\nimport torch.nn.functional as "
},
{
"path": "models/InstructDiffusion/.gitignore",
"chars": 3411,
"preview": "data/\ncheckpoints/\nstable_diffusion/models/ldm/stable-diffusion-v1/\nsrc/\nlogs/\ncache/\nimgs/*\nwork_dirs/*\nwandb/*\nDeepSpe"
},
{
"path": "models/InstructDiffusion/LICENSE",
"chars": 1655,
"preview": "Copyright 2023 Authors of InstructDiffusion(https://arxiv.org/pdf/2309.03895.pdf)\n\nPermission is hereby granted, free of"
},
{
"path": "models/InstructDiffusion/README.md",
"chars": 5117,
"preview": "# InstructDiffusion: A Generalist Modeling Interface for Vision Tasks\n\n<p align=\"center\">\n <a href=\"https://gengzigang."
},
{
"path": "models/InstructDiffusion/configs/instruct_diffusion.yaml",
"chars": 6774,
"preview": "# File modified by authors of InstructDiffusion from original (https://github.com/CompVis/stable-diffusion).\n# See more "
},
{
"path": "models/InstructDiffusion/dataset/README.md",
"chars": 2338,
"preview": "You can download these datasets: [COCO](http://cocodataset.org/#download), [CrowdPose](https://github.com/Jeff-sjtu/Crow"
},
{
"path": "models/InstructDiffusion/dataset/editing/edit_zip_dataset.py",
"chars": 19930,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/low_level/lowlevel_clwd.py",
"chars": 4226,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/low_level/lowlevel_gopro.py",
"chars": 4179,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/low_level/lowlevel_reds.py",
"chars": 4644,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/low_level/lowlevel_sidd.py",
"chars": 3719,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/pose/pose.py",
"chars": 26174,
"preview": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed un"
},
{
"path": "models/InstructDiffusion/dataset/prompt/color_list_train_small.txt",
"chars": 217,
"preview": "Red 纯红 #FF0000 255,0,0\n\nPurple 紫色 #800080 128,0,128\n\nBlue 纯蓝 #0000FF 0,0,255\n\nGreen 纯绿 #008000 0,128,0\n\nYellow 纯黄 #FFFF0"
},
{
"path": "models/InstructDiffusion/dataset/prompt/prompt_deblur.txt",
"chars": 408,
"preview": "Sharpen this blurry image\nIncrease the sharpness of this unclear photo\nCorrect the lack of focus in this misty picture\nH"
},
{
"path": "models/InstructDiffusion/dataset/prompt/prompt_denoise.txt",
"chars": 400,
"preview": "Remove noise from this image\nEliminate the noise in this picture\nPurify this photo by removing noise\nClear up the image "
},
{
"path": "models/InstructDiffusion/dataset/prompt/prompt_dewatermark.txt",
"chars": 451,
"preview": "Remove watermark from this picture\nErase the watermark from this photograph.\nExtract the watermark from this image.\nTake"
},
{
"path": "models/InstructDiffusion/dataset/prompt/prompt_pose.txt",
"chars": 592,
"preview": "Circle the {joint} of the people with the color {color}, \nUse the {color} color to draw circles around the {joint} of th"
},
{
"path": "models/InstructDiffusion/dataset/prompt/prompt_seg.txt",
"chars": 922,
"preview": "Mark the pixels of {object} in {color} and leave the rest unchanged.\nColor the {object}'s pixels in {color}, keeping the"
},
{
"path": "models/InstructDiffusion/dataset/seg/coco_stuff.py",
"chars": 6501,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/seg/grefcoco.py",
"chars": 12435,
"preview": "\"\"\"\ngrefer v0.1\nThis interface provides access to gRefCOCO.\n\nThe following API functions are defined:\nG_REFER - REF"
},
{
"path": "models/InstructDiffusion/dataset/seg/grefcoco_segmentation.py",
"chars": 5562,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/seg/refcoco.py",
"chars": 12433,
"preview": "__author__ = 'licheng'\n\n\"\"\"\nThis interface provides access to four datasets:\n1) refclef\n2) refcoco\n3) refcoco+\n4) refcoc"
},
{
"path": "models/InstructDiffusion/dataset/seg/refcoco_segmentation.py",
"chars": 5540,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/dataset/utils/zip_manager.py",
"chars": 3479,
"preview": "import zipfile\nimport os.path as osp\n# import lmdb\nimport logging\nfrom PIL import Image\nimport pickle\nimport io\nimport g"
},
{
"path": "models/InstructDiffusion/edit_app.py",
"chars": 14212,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/edit_cli.py",
"chars": 5229,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/environment.yaml",
"chars": 1017,
"preview": "# File modified by authors of InstructDiffusion from original (https://github.com/CompVis/stable-diffusion).\n# See more "
},
{
"path": "models/InstructDiffusion/main.py",
"chars": 20814,
"preview": "# --------------------------------------------------------\n# InstructDiffusion\n# Based on instruct-pix2pix (https://gith"
},
{
"path": "models/InstructDiffusion/scripts/convert_ckpt.py",
"chars": 1988,
"preview": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed un"
},
{
"path": "models/InstructDiffusion/scripts/download_pretrained_sd.sh",
"chars": 570,
"preview": "#!/bin/bash\n\nSCRIPT_DIR=$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\n\nmkdir -p $SCRIPT_DIR/../st"
},
{
"path": "models/InstructDiffusion/scripts/inference_example.sh",
"chars": 1891,
"preview": "# Example: Image Editing\npython edit_cli.py --input figure/animals.png --edit \"Transform it to van Gogh, starry night st"
},
{
"path": "models/InstructDiffusion/scripts/run_multinode.sh",
"chars": 282,
"preview": "EXP=$1\nNAME=$2\nGPUMUM=$3\nset -x \n\npython -m torch.distributed.launch --nnodes=${GPUMUM} --nproc_per_node=8 --node_rank=$"
},
{
"path": "models/InstructDiffusion/stable_diffusion/LICENSE",
"chars": 14381,
"preview": "Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors\n\nCreativeML Open RAIL-M\ndated August 22, 2022\n\nSecti"
},
{
"path": "models/InstructDiffusion/stable_diffusion/README.md",
"chars": 12439,
"preview": "# Stable Diffusion\n*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.a"
},
{
"path": "models/InstructDiffusion/stable_diffusion/Stable_Diffusion_v1_Model_Card.md",
"chars": 9340,
"preview": "# Stable Diffusion v1 Model Card\nThis model card focuses on the model associated with the Stable Diffusion model, availa"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/results.gif.REMOVED.git-id",
"chars": 40,
"preview": "82b6590e670a32196093cc6333ea19e6547d07de"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/stable-samples/img2img/upscaling-in.png.REMOVED.git-id",
"chars": 40,
"preview": "501c31c21751664957e69ce52cad1818b6d2f4ce"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/stable-samples/img2img/upscaling-out.png.REMOVED.git-id",
"chars": 40,
"preview": "1c4bb25a779f34d86b2d90e584ac67af91bb1303"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/stable-samples/txt2img/merged-0005.png.REMOVED.git-id",
"chars": 40,
"preview": "ca0a1af206555f0f208a1ab879e95efedc1b1c5b"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/stable-samples/txt2img/merged-0006.png.REMOVED.git-id",
"chars": 40,
"preview": "999f3703230580e8c89e9081abd6a1f8f50896d4"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/stable-samples/txt2img/merged-0007.png.REMOVED.git-id",
"chars": 40,
"preview": "af390acaf601283782d6f479d4cade4d78e30b26"
},
{
"path": "models/InstructDiffusion/stable_diffusion/assets/txt2img-preview.png.REMOVED.git-id",
"chars": 40,
"preview": "51ee1c235dfdc63d4c41de7d303d03730e43c33c"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml",
"chars": 1145,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml",
"chars": 1140,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml",
"chars": 1139,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml",
"chars": 1148,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml",
"chars": 2028,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml",
"chars": 2360,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/cin256-v2.yaml",
"chars": 1553,
"preview": "model:\n base_learning_rate: 0.0001\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.00"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml",
"chars": 2020,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml",
"chars": 2024,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml",
"chars": 2284,
"preview": "model:\n base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'\n target: ldm.model"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml",
"chars": 1614,
"preview": "model:\n base_learning_rate: 5.0e-05\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/retrieval-augmented-diffusion/768x768.yaml",
"chars": 1615,
"preview": "model:\n base_learning_rate: 0.0001\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.00"
},
{
"path": "models/InstructDiffusion/stable_diffusion/configs/stable-diffusion/v1-inference.yaml",
"chars": 1873,
"preview": "model:\n base_learning_rate: 1.0e-04\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/environment.yaml",
"chars": 734,
"preview": "name: ldm\nchannels:\n - pytorch\n - defaults\ndependencies:\n - python=3.8.5\n - pip=20.3\n - cudatoolkit=11.3\n - pytorc"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/lr_scheduler.py",
"chars": 3882,
"preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n \"\"\"\n note: use with a base_lr of 1.0\n \"\"\"\n def __in"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/autoencoder.py",
"chars": 17925,
"preview": "# --------------------------------------------------------\n# Stable-Diffusion-Torch\n# Based on Stable-Diffusion (https:/"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/classifier.py",
"chars": 10276,
"preview": "import os\nimport torch\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom torch.nn import functional as"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddim.py",
"chars": 12797,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddpm.py",
"chars": 67425,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/ddpm_edit.py",
"chars": 57786,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/dpm_solver/__init__.py",
"chars": 37,
"preview": "from .sampler import DPMSolverSampler"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py",
"chars": 64057,
"preview": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass NoiseScheduleVP:\n def __init__(\n self,\n "
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/dpm_solver/sampler.py",
"chars": 2908,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\n\nfrom .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver\n\n\nclass DPMSolver"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/models/diffusion/plms.py",
"chars": 12450,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/attention.py",
"chars": 14776,
"preview": "# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).\n# See more de"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/model.py",
"chars": 35868,
"preview": "# pytorch_diffusion + derived encoder decoder\n# Removed Pytorch-lightning by Zigang Geng (zigang@mail.ustc.edu.cn)\n\nimpo"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py",
"chars": 38372,
"preview": "# --------------------------------------------------------\n# Stable-Diffusion-Torch\n# Based on Stable-Diffusion (https:/"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/diffusionmodules/util.py",
"chars": 10331,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/distributions/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/distributions/distributions.py",
"chars": 2970,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/ema.py",
"chars": 3730,
"preview": "# --------------------------------------------------------\n# Stable-Diffusion-Torch\n# Based on Stable-Diffusion (https:/"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/encoders/modules.py",
"chars": 8154,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom t"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/__init__.py",
"chars": 208,
"preview": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/bsrgan.py",
"chars": 25198,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py",
"chars": 22238,
"preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/image_degradation/utils_image.py",
"chars": 29022,
"preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/losses/__init__.py",
"chars": 68,
"preview": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/losses/contperceptual.py",
"chars": 5581,
"preview": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/losses/vqperceptual.py",
"chars": 7941,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/modules/x_transformer.py",
"chars": 20168,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import "
},
{
"path": "models/InstructDiffusion/stable_diffusion/ldm/util.py",
"chars": 5857,
"preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
},
{
"path": "models/InstructDiffusion/stable_diffusion/main.py",
"chars": 28229,
"preview": "import argparse, os, sys, datetime, glob, importlib, csv\nimport numpy as np\nimport time\nimport torch\nimport torchvision\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/kl-f16/config.yaml",
"chars": 909,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/kl-f32/config.yaml",
"chars": 929,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/kl-f4/config.yaml",
"chars": 880,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/kl-f8/config.yaml",
"chars": 889,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/vq-f16/config.yaml",
"chars": 1026,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 8\n n_embed: 16"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/vq-f4/config.yaml",
"chars": 955,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 3\n n_embed: 81"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/vq-f4-noattn/config.yaml",
"chars": 978,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 3\n n_embed: 81"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/vq-f8/config.yaml",
"chars": 1035,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 4\n n_embed: 16"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/first_stage_models/vq-f8-n256/config.yaml",
"chars": 1013,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 4\n n_embed: 25"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/bsr_sr/config.yaml",
"chars": 1900,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/celeba256/config.yaml",
"chars": 1599,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/cin256/config.yaml",
"chars": 1862,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/ffhq256/config.yaml",
"chars": 1591,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/inpainting_big/config.yaml",
"chars": 1619,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/layout2img-openimages256/config.yaml",
"chars": 1924,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/lsun_beds256/config.yaml",
"chars": 1601,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/lsun_churches256/config.yaml",
"chars": 2018,
"preview": "model:\n base_learning_rate: 5.0e-05\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/semantic_synthesis256/config.yaml",
"chars": 1378,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/semantic_synthesis512/config.yaml",
"chars": 1820,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/models/ldm/text2img256/config.yaml",
"chars": 1831,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/InstructDiffusion/stable_diffusion/notebook_helpers.py",
"chars": 10099,
"preview": "from torchvision.datasets.utils import download_url\nfrom ldm.util import instantiate_from_config\nimport torch\nimport os\n"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/download_first_stages.sh",
"chars": 1324,
"preview": "#!/bin/bash\nwget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip\nwge"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/download_models.sh",
"chars": 1681,
"preview": "#!/bin/bash\nwget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip\nwget -O "
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/img2img.py",
"chars": 9181,
"preview": "\"\"\"make variations of input image\"\"\"\n\nimport argparse, os, sys, glob\nimport PIL\nimport torch\nimport numpy as np\nfrom ome"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/inpaint.py",
"chars": 3644,
"preview": "import argparse, os, sys, glob\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom tqdm import tqdm\nimport numpy "
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/knn2img.py",
"chars": 13707,
"preview": "import argparse, os, sys, glob\nimport clip\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom omegaconf import O"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/latent_imagenet_diffusion.ipynb.REMOVED.git-id",
"chars": 40,
"preview": "607f94fc7d3ef6d8d1627017215476d9dfc7ddc4"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/sample_diffusion.py",
"chars": 9606,
"preview": "import argparse, os, sys, glob, datetime, yaml\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import trange\n\nfrom"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/tests/test_watermark.py",
"chars": 357,
"preview": "import cv2\nimport fire\nfrom imwatermark import WatermarkDecoder\n\n\ndef testit(img_path):\n bgr = cv2.imread(img_path)\n "
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/train_searcher.py",
"chars": 5807,
"preview": "import os, sys\nimport numpy as np\nimport scann\nimport argparse\nimport glob\nfrom multiprocessing import cpu_count\nfrom tq"
},
{
"path": "models/InstructDiffusion/stable_diffusion/scripts/txt2img.py",
"chars": 11666,
"preview": "import argparse, os, sys, glob\nimport cv2\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL impor"
},
{
"path": "models/InstructDiffusion/stable_diffusion/setup.py",
"chars": 233,
"preview": "from setuptools import setup, find_packages\n\nsetup(\n name='latent-diffusion',\n version='0.0.1',\n description=''"
},
{
"path": "models/InstructDiffusion/utils/deepspeed.py",
"chars": 2104,
"preview": "import os\nimport torch\nimport torch.distributed as dist\nimport json\n\n\ndef create_ds_config(args, config, cfgdir):\n co"
},
{
"path": "models/InstructDiffusion/utils/logger.py",
"chars": 1451,
"preview": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed "
},
{
"path": "models/InstructDiffusion/utils/utils.py",
"chars": 7327,
"preview": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed "
},
{
"path": "models/edict/edict_functions.py",
"chars": 41172,
"preview": "import torch\nfrom transformers import CLIPModel, CLIPTextModel, CLIPTokenizer\nfrom omegaconf import OmegaConf\nimport mat"
},
{
"path": "models/edict/my_diffusers/__init__.py",
"chars": 1672,
"preview": "from .utils import (\n is_inflect_available,\n is_onnx_available,\n is_scipy_available,\n is_transformers_availa"
},
{
"path": "models/edict/my_diffusers/commands/__init__.py",
"chars": 920,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/commands/diffusers_cli.py",
"chars": 1200,
"preview": "#!/usr/bin/env python\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License,"
},
{
"path": "models/edict/my_diffusers/commands/env.py",
"chars": 2384,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/configuration_utils.py",
"chars": 18472,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserve"
},
{
"path": "models/edict/my_diffusers/dependency_versions_check.py",
"chars": 1756,
"preview": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/dependency_versions_table.py",
"chars": 836,
"preview": "# THIS FILE HAS BEEN AUTOGENERATED. To update:\n# 1. modify the `_deps` dict in setup.py\n# 2. run `make deps_table_update"
},
{
"path": "models/edict/my_diffusers/dynamic_modules_utils.py",
"chars": 14206,
"preview": "# coding=utf-8\n# Copyright 2021 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
},
{
"path": "models/edict/my_diffusers/hub_utils.py",
"chars": 7602,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
},
{
"path": "models/edict/my_diffusers/modeling_utils.py",
"chars": 25059,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserve"
},
{
"path": "models/edict/my_diffusers/models/__init__.py",
"chars": 732,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/models/attention.py",
"chars": 13351,
"preview": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass Atten"
},
{
"path": "models/edict/my_diffusers/models/embeddings.py",
"chars": 3895,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/models/resnet.py",
"chars": 18510,
"preview": "from functools import partial\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nc"
},
{
"path": "models/edict/my_diffusers/models/unet_2d.py",
"chars": 10715,
"preview": "from dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom .."
},
{
"path": "models/edict/my_diffusers/models/unet_2d_condition.py",
"chars": 12180,
"preview": "from dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom .."
},
{
"path": "models/edict/my_diffusers/models/unet_blocks.py",
"chars": 52429,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/models/vae.py",
"chars": 20989,
"preview": "from dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torc"
},
{
"path": "models/edict/my_diffusers/onnx_utils.py",
"chars": 7314,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserve"
},
{
"path": "models/edict/my_diffusers/optimization.py",
"chars": 11363,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lice"
},
{
"path": "models/edict/my_diffusers/pipeline_utils.py",
"chars": 18785,
"preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserve"
},
{
"path": "models/edict/my_diffusers/pipelines/__init__.py",
"chars": 668,
"preview": "from ..utils import is_onnx_available, is_transformers_available\nfrom .ddim import DDIMPipeline\nfrom .ddpm import DDPMPi"
},
{
"path": "models/edict/my_diffusers/pipelines/ddim/__init__.py",
"chars": 55,
"preview": "# flake8: noqa\nfrom .pipeline_ddim import DDIMPipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/ddim/pipeline_ddim.py",
"chars": 4844,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/pipelines/ddpm/__init__.py",
"chars": 55,
"preview": "# flake8: noqa\nfrom .pipeline_ddpm import DDPMPipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/ddpm/pipeline_ddpm.py",
"chars": 4247,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/pipelines/latent_diffusion/__init__.py",
"chars": 176,
"preview": "# flake8: noqa\nfrom ...utils import is_transformers_available\n\n\nif is_transformers_available():\n from .pipeline_laten"
},
{
"path": "models/edict/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py",
"chars": 30846,
"preview": "import inspect\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimpor"
},
{
"path": "models/edict/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py",
"chars": 73,
"preview": "# flake8: noqa\nfrom .pipeline_latent_diffusion_uncond import LDMPipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py",
"chars": 4520,
"preview": "import inspect\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\nfrom ...models import UNet2DMode"
},
{
"path": "models/edict/my_diffusers/pipelines/pndm/__init__.py",
"chars": 55,
"preview": "# flake8: noqa\nfrom .pipeline_pndm import PNDMPipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/pndm/pipeline_pndm.py",
"chars": 4612,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/pipelines/score_sde_ve/__init__.py",
"chars": 69,
"preview": "# flake8: noqa\nfrom .pipeline_score_sde_ve import ScoreSdeVePipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py",
"chars": 4210,
"preview": "#!/usr/bin/env python3\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\nfrom ...models import UN"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/__init__.py",
"chars": 1357,
"preview": "from dataclasses import dataclass\nfrom typing import List, Union\n\nimport numpy as np\n\nimport PIL\nfrom PIL import Image\n\n"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py",
"chars": 14141,
"preview": "import inspect\nimport warnings\nfrom typing import List, Optional, Union\n\nimport torch\n\nfrom transformers import CLIPFeat"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py",
"chars": 14908,
"preview": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\n\nimport PIL\nfrom transformers i"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py",
"chars": 15478,
"preview": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\n\nimport PIL\nfrom tqdm.auto impo"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py",
"chars": 7192,
"preview": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\n\nfrom transformers import CLIPFeatureExtract"
},
{
"path": "models/edict/my_diffusers/pipelines/stable_diffusion/safety_checker.py",
"chars": 4735,
"preview": "import numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom transformers import CLIPConfig, CLIPVisionModel, PreTrainedM"
},
{
"path": "models/edict/my_diffusers/pipelines/stochastic_karras_ve/__init__.py",
"chars": 75,
"preview": "# flake8: noqa\nfrom .pipeline_stochastic_karras_ve import KarrasVePipeline\n"
},
{
"path": "models/edict/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py",
"chars": 5581,
"preview": "#!/usr/bin/env python3\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\nfrom ...models import UN"
},
{
"path": "models/edict/my_diffusers/schedulers/__init__.py",
"chars": 1128,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_ddim.py",
"chars": 12196,
"preview": "# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache L"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_ddpm.py",
"chars": 11667,
"preview": "# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, V"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_karras_ve.py",
"chars": 8927,
"preview": "# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2."
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_lms_discrete.py",
"chars": 8064,
"preview": "# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License,"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_pndm.py",
"chars": 16983,
"preview": "# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache L"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_sde_ve.py",
"chars": 13179,
"preview": "# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Vers"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_sde_vp.py",
"chars": 3118,
"preview": "# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Vers"
},
{
"path": "models/edict/my_diffusers/schedulers/scheduling_utils.py",
"chars": 4451,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/testing_utils.py",
"chars": 1651,
"preview": "import os\nimport random\nimport unittest\nfrom distutils.util import strtobool\n\nimport torch\n\nfrom packaging import versio"
},
{
"path": "models/edict/my_diffusers/training_utils.py",
"chars": 4110,
"preview": "import copy\nimport os\nimport random\n\nimport numpy as np\nimport torch\n\n\ndef enable_full_determinism(seed: int):\n \"\"\"\n "
},
{
"path": "models/edict/my_diffusers/utils/__init__.py",
"chars": 1584,
"preview": "# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the"
},
{
"path": "models/edict/my_diffusers/utils/dummy_scipy_objects.py",
"chars": 307,
"preview": "# This file is autogenerated by the command `make fix-copies`, do not edit.\n# flake8: noqa\n\nfrom ..utils import DummyObj"
},
{
"path": "models/edict/my_diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py",
"chars": 363,
"preview": "# This file is autogenerated by the command `make fix-copies`, do not edit.\n# flake8: noqa\nfrom ..utils import DummyObje"
},
{
"path": "models/edict/my_diffusers/utils/dummy_transformers_and_onnx_objects.py",
"chars": 344,
"preview": "# This file is autogenerated by the command `make fix-copies`, do not edit.\n# flake8: noqa\n\nfrom ..utils import DummyObj"
},
{
"path": "models/edict/my_diffusers/utils/dummy_transformers_objects.py",
"chars": 880,
"preview": "# This file is autogenerated by the command `make fix-copies`, do not edit.\n# flake8: noqa\n\nfrom ..utils import DummyObj"
},
{
"path": "models/edict/my_diffusers/utils/import_utils.py",
"chars": 9242,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edict/my_diffusers/utils/logging.py",
"chars": 9411,
"preview": "# coding=utf-8\n# Copyright 2020 Optuna, Hugging Face\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
},
{
"path": "models/edict/my_diffusers/utils/model_card_template.md",
"chars": 1478,
"preview": "---\n{{ card_data }}\n---\n\n<!-- This model card has been generated automatically according to the information the training"
},
{
"path": "models/edict/my_diffusers/utils/outputs.py",
"chars": 3744,
"preview": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "models/edit_friendly_ddm/inversion_utils.py",
"chars": 11215,
"preview": "import torch\nimport os\n\ndef load_real_image(folder = \"data/\", img_name = None, idx = 0, img_size=512, device='cuda'):\n "
},
{
"path": "models/edit_friendly_ddm/ptp_classes.py",
"chars": 12923,
"preview": "\"\"\"\nThis code was originally taken from\nhttps://github.com/google/prompt-to-prompt\n\"\"\"\n\n\nLOW_RESOURCE = True \nMAX_NUM_WO"
},
{
"path": "models/edit_friendly_ddm/ptp_utils.py",
"chars": 11514,
"preview": "\"\"\"\nThis code was originally taken from\nhttps://github.com/google/prompt-to-prompt\n\"\"\"\n\n# Copyright 2022 Google LLC\n#\n# "
},
{
"path": "models/edit_friendly_ddm/seq_aligner.py",
"chars": 6650,
"preview": "\"\"\"\nThis code was originally taken from\nhttps://github.com/google/prompt-to-prompt\n\"\"\"\n\n# Copyright 2022 Google LLC\n#\n# "
},
{
"path": "models/instructpix2pix/LICENSE",
"chars": 1499,
"preview": "Copyright 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros\n\nPermission is hereby granted, free of charge, to an"
},
{
"path": "models/instructpix2pix/README.md",
"chars": 16760,
"preview": "# InstructPix2Pix: Learning to Follow Image Editing Instructions\n### [Project Page](https://www.timothybrooks.com/instru"
}
]
// ... and 139 more files (not shown in this truncated preview)
About this extraction
This document contains the source code of the cure-lab/PnPInversion GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 339 files (9.7 MB, approximately 2.6M tokens) and includes a symbol index of 2669 extracted functions, classes, methods, constants, and types.
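Because the file index above is a flat JSON array of records with "path", "chars", and "preview" fields, a few lines of Python are enough to summarize it. The following is a minimal sketch, assuming the array has been saved to a local file named file_index.json (a hypothetical name chosen here for illustration); it totals character counts per top-level directory and filters out the .REMOVED.git-id placeholder entries, which hold only a 40-character git object hash rather than file content.

import json
from collections import Counter
from pathlib import PurePosixPath

# Load the file index (hypothetical file name; the array of
# {"path", "chars", "preview"} records is assumed to be valid JSON).
with open("file_index.json") as f:
    entries = json.load(f)

# Total characters across all indexed files.
total_chars = sum(e["chars"] for e in entries)

# Group character counts by top-level directory (e.g. models/, evaluation/)
# to see where the bulk of the code lives.
by_top_dir = Counter()
for e in entries:
    top = PurePosixPath(e["path"]).parts[0]
    by_top_dir[top] += e["chars"]

print(f"{len(entries)} files, {total_chars:,} chars")
for top, chars in by_top_dir.most_common():
    print(f"  {top:<20} {chars:>12,} chars")

# .REMOVED.git-id entries are placeholders for large binaries: they contain
# only a git object hash (40 chars), so drop them when analyzing code.
real_files = [e for e in entries if not e["path"].endswith(".REMOVED.git-id")]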