Repository: vladmandic/sdnext Branch: master Commit: c9e21a51db77 Files: 1302 Total size: 30.7 MB Directory structure: gitextract_kfv17jig/ ├── .dockerignore ├── .gitignore ├── .gitmodules ├── .pylintrc ├── .ruff.toml ├── CHANGELOG.md ├── CITATION.cff ├── CODE_OF_CONDUCT ├── CONTRIBUTING ├── README.md ├── SECURITY.md ├── TODO.md ├── cli/ │ ├── api-checkpoint.py │ ├── api-control.py │ ├── api-detect.py │ ├── api-enhance.py │ ├── api-faceid.py │ ├── api-grid.py │ ├── api-history.py │ ├── api-img2img.py │ ├── api-info.py │ ├── api-interrogate.py │ ├── api-json.py │ ├── api-mask.py │ ├── api-model.js │ ├── api-preprocess.py │ ├── api-progress.py │ ├── api-pulid.js │ ├── api-samplers.py │ ├── api-txt2img.js │ ├── api-txt2img.py │ ├── api-upscale.py │ ├── api-vqa.py │ ├── api-xyz.py │ ├── api-xyzenum.py │ ├── civitai-search.py │ ├── download-file.py │ ├── full-test.sh │ ├── gen-styles.py │ ├── generate-random.json │ ├── generate.json │ ├── generate.py │ ├── git-clone.py │ ├── hf-search.py │ ├── image-encode.py │ ├── image-exif.py │ ├── image-grid.py │ ├── image-palette.py │ ├── image-search.py │ ├── image-watermark.py │ ├── install-stablefast.py │ ├── lcm-convert.py │ ├── load-unet.py │ ├── locale-sanitize-override.py │ ├── localize.js │ ├── model-keys.py │ ├── model-metadata.py │ ├── nvidia-smi.py │ ├── process.py │ ├── process_options.py │ ├── requirements.txt │ ├── run-benchmark.py │ ├── sdapi.py │ ├── search-docs.py │ ├── test-schedulers.py │ ├── test-tagger.py │ ├── test-weighted-lists.py │ ├── util.py │ ├── validate-locale.py │ ├── video-extract.py │ └── zluda-python.py ├── configs/ │ ├── Dockerfile.cuda │ ├── Dockerfile.ipex │ ├── Dockerfile.openvino │ ├── Dockerfile.rocm │ ├── chroma/ │ │ ├── model_index.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder/ │ │ │ └── config.json │ │ ├── tokenizer/ │ │ │ ├── added_tokens.json │ │ │ ├── special_tokens_map.json │ │ │ ├── spiece.model │ │ │ └── tokenizer_config.json │ │ ├── transformer/ │ │ │ 
├── config.json │ │ │ └── diffusion_pytorch_model.safetensors.index.json │ │ └── vae/ │ │ └── config.json │ ├── flux/ │ │ ├── model_index.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder/ │ │ │ └── config.json │ │ ├── text_encoder_2/ │ │ │ ├── config.json │ │ │ └── model.safetensors.index.json │ │ ├── tokenizer/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── tokenizer_2/ │ │ │ ├── special_tokens_map.json │ │ │ ├── spiece.model │ │ │ ├── tokenizer.json │ │ │ └── tokenizer_config.json │ │ ├── transformer/ │ │ │ ├── config.json │ │ │ └── diffusion_pytorch_model.safetensors.index.json │ │ └── vae/ │ │ └── config.json │ ├── olive/ │ │ ├── sd/ │ │ │ ├── text_encoder.json │ │ │ ├── unet.json │ │ │ ├── vae_decoder.json │ │ │ └── vae_encoder.json │ │ └── sdxl/ │ │ ├── text_encoder.json │ │ ├── text_encoder_2.json │ │ ├── unet.json │ │ ├── vae_decoder.json │ │ └── vae_encoder.json │ ├── playground-v2.5-1024px-aesthetic.fp16_vae.json │ ├── sd15/ │ │ ├── feature_extractor/ │ │ │ └── preprocessor_config.json │ │ ├── model_index.json │ │ ├── safety_checker/ │ │ │ └── config.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder/ │ │ │ └── config.json │ │ ├── tokenizer/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── unet/ │ │ │ └── config.json │ │ └── vae/ │ │ └── config.json │ ├── sd3/ │ │ ├── model_index.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder/ │ │ │ └── config.json │ │ ├── text_encoder_2/ │ │ │ └── config.json │ │ ├── text_encoder_3/ │ │ │ └── config.json │ │ ├── tokenizer/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── tokenizer_2/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── tokenizer_3/ │ │ │ ├── 
special_tokens_map.json │ │ │ ├── spiece.model │ │ │ ├── tokenizer.json │ │ │ └── tokenizer_config.json │ │ ├── transformer/ │ │ │ └── config.json │ │ └── vae/ │ │ └── config.json │ ├── sdxl/ │ │ ├── model_index.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder/ │ │ │ └── config.json │ │ ├── text_encoder_2/ │ │ │ └── config.json │ │ ├── tokenizer/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── tokenizer_2/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── unet/ │ │ │ └── config.json │ │ └── vae/ │ │ └── config.json │ ├── sdxl-refiner/ │ │ ├── model_index.json │ │ ├── scheduler/ │ │ │ └── scheduler_config.json │ │ ├── text_encoder_2/ │ │ │ └── config.json │ │ ├── tokenizer_2/ │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── unet/ │ │ │ └── config.json │ │ └── vae/ │ │ └── config.json │ └── stable-cascade/ │ ├── prior/ │ │ └── config.json │ └── prior_lite/ │ └── config.json ├── data/ │ ├── previews.json │ ├── reference-cloud.json │ ├── reference-community.json │ ├── reference-distilled.json │ ├── reference-quant.json │ ├── reference.json │ └── upscalers.json ├── eslint.config.mjs ├── html/ │ ├── art-styles.json │ ├── licenses.html │ ├── locale_de.json │ ├── locale_en.json │ ├── locale_es.json │ ├── locale_fr.json │ ├── locale_hr.json │ ├── locale_it.json │ ├── locale_ja.json │ ├── locale_ko.json │ ├── locale_pt.json │ ├── locale_ru.json │ ├── locale_zh.json │ ├── manifest.json │ ├── override_en.json │ ├── override_hr.json │ ├── override_ko.json │ └── swagger.css ├── installer.py ├── javascript/ │ ├── amethyst-nightfall.css │ ├── aspectRatioOverlay.js │ ├── authWrap.js │ ├── base.css │ ├── black-gray.css │ ├── black-orange.css │ ├── black-teal-reimagined.css │ ├── black-teal.css │ ├── changelog.js │ ├── civitai.js │ ├── contextMenus.js │ ├── 
control.js │ ├── docs.js │ ├── dragDrop.js │ ├── editAttention.js │ ├── emerald-paradise.css │ ├── exifr.js │ ├── extensions.js │ ├── extraNetworks.js │ ├── gallery.js │ ├── generationParams.js │ ├── gpu.js │ ├── guidance.js │ ├── hires.js │ ├── history.js │ ├── imageParams.js │ ├── imageViewer.js │ ├── indexdb.js │ ├── inputAccordion.js │ ├── invoked.css │ ├── light-teal.css │ ├── loader.js │ ├── logMonitor.js │ ├── logger.js │ ├── login.js │ ├── midnight-barbie.css │ ├── monitor.js │ ├── notification.js │ ├── orchid-dreams.css │ ├── panZoom.js │ ├── progressBar.js │ ├── promptChecker.js │ ├── script.js │ ├── sdnext.css │ ├── setHints.js │ ├── settings.js │ ├── simple-dark.css │ ├── simple-light.css │ ├── startup.js │ ├── timeless-beige.css │ ├── timesheet.css │ ├── timesheet.js │ ├── trainMonitor.js │ ├── ui.js │ └── uiConfig.js ├── launch.py ├── models/ │ └── VAE-approx/ │ └── model.pt ├── modules/ │ ├── apg/ │ │ ├── __init__.py │ │ ├── pipeline_stable_cascade_prior_apg.py │ │ ├── pipeline_stable_diffision_xl_apg.py │ │ └── pipeline_stable_diffusion_apg.py │ ├── api/ │ │ ├── api.py │ │ ├── control.py │ │ ├── docs.py │ │ ├── endpoints.py │ │ ├── gallery.py │ │ ├── generate.py │ │ ├── gpu.py │ │ ├── helpers.py │ │ ├── loras.py │ │ ├── middleware.py │ │ ├── mime.py │ │ ├── models.py │ │ ├── nudenet.py │ │ ├── nvml.py │ │ ├── process.py │ │ ├── rocm_smi.py │ │ ├── script.py │ │ ├── server.py │ │ ├── xpu_smi.py │ │ └── xyz_grid.py │ ├── attention.py │ ├── ben2/ │ │ ├── __init__.py │ │ └── ben2_model.py │ ├── cachedit.py │ ├── call_queue.py │ ├── cfgzero/ │ │ ├── __init__.py │ │ ├── cogview4_pipeline.py │ │ ├── flux_pipeline.py │ │ ├── hidream_pipeline.py │ │ ├── hunyuan_t2v_pipeline.py │ │ ├── sd3_pipeline.py │ │ └── wan_t2v_pipeline.py │ ├── civitai/ │ │ ├── api_civitai.py │ │ ├── download_civitai.py │ │ ├── metadata_civitai.py │ │ └── search_civitai.py │ ├── cmd_args.py │ ├── control/ │ │ ├── proc/ │ │ │ ├── __init__.py │ │ │ ├── canny.py │ │ │ ├── depth_anything/ 
│ │ │ │ ├── __init__.py │ │ │ │ ├── blocks.py │ │ │ │ ├── dpt.py │ │ │ │ └── util/ │ │ │ │ └── transform.py │ │ │ ├── depth_pro/ │ │ │ │ └── __init__.py │ │ │ ├── dpt.py │ │ │ ├── dwpose/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config/ │ │ │ │ │ ├── dwpose-l_384x288.py │ │ │ │ │ ├── rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py │ │ │ │ │ ├── rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py │ │ │ │ │ ├── rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py │ │ │ │ │ └── yolox_l_8xb8-300e_coco.py │ │ │ │ ├── draw.py │ │ │ │ └── wholebody.py │ │ │ ├── edge.py │ │ │ ├── glpn.py │ │ │ ├── hed.py │ │ │ ├── leres/ │ │ │ │ ├── __init__.py │ │ │ │ ├── leres/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── Resnet.py │ │ │ │ │ ├── Resnext_torch.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── depthmap.py │ │ │ │ │ ├── multi_depth_model_woauxi.py │ │ │ │ │ ├── net_tools.py │ │ │ │ │ └── network_auxi.py │ │ │ │ └── pix2pix/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── models/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_model.py │ │ │ │ │ ├── base_model_hg.py │ │ │ │ │ ├── networks.py │ │ │ │ │ └── pix2pix4depth_model.py │ │ │ │ ├── options/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_options.py │ │ │ │ │ └── test_options.py │ │ │ │ └── util/ │ │ │ │ ├── __init__.py │ │ │ │ └── util.py │ │ │ ├── lineart.py │ │ │ ├── lineart_anime.py │ │ │ ├── marigold/ │ │ │ │ ├── __init__.py │ │ │ │ ├── marigold_pipeline.py │ │ │ │ └── util/ │ │ │ │ ├── batchsize.py │ │ │ │ ├── ensemble.py │ │ │ │ ├── image_util.py │ │ │ │ └── seed_all.py │ │ │ ├── mediapipe_face.py │ │ │ ├── mediapipe_face_util.py │ │ │ ├── midas/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── api.py │ │ │ │ ├── midas/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_model.py │ │ │ │ │ ├── blocks.py │ │ │ │ │ ├── dpt_depth.py │ │ │ │ │ ├── midas_net.py │ │ │ │ │ ├── midas_net_custom.py │ │ │ │ │ ├── transforms.py │ │ │ │ │ └── vit.py │ │ │ │ └── utils.py │ │ │ ├── mlsd/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ 
│ ├── models/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── mbv2_mlsd_large.py │ │ │ │ │ └── mbv2_mlsd_tiny.py │ │ │ │ └── utils.py │ │ │ ├── normalbae/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ └── nets/ │ │ │ │ ├── NNET.py │ │ │ │ ├── __init__.py │ │ │ │ ├── baseline.py │ │ │ │ └── submodules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── decoder.py │ │ │ │ ├── efficientnet_repo/ │ │ │ │ │ ├── BENCHMARK.md │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── README.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── geffnet/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── activations/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── activations.py │ │ │ │ │ │ │ ├── activations_jit.py │ │ │ │ │ │ │ └── activations_me.py │ │ │ │ │ │ ├── config.py │ │ │ │ │ │ ├── conv2d_layers.py │ │ │ │ │ │ ├── efficientnet_builder.py │ │ │ │ │ │ ├── gen_efficientnet.py │ │ │ │ │ │ ├── helpers.py │ │ │ │ │ │ ├── mobilenetv3.py │ │ │ │ │ │ ├── model_factory.py │ │ │ │ │ │ └── version.py │ │ │ │ │ ├── hubconf.py │ │ │ │ │ ├── requirements.txt │ │ │ │ │ ├── setup.py │ │ │ │ │ ├── utils.py │ │ │ │ │ └── validate.py │ │ │ │ ├── encoder.py │ │ │ │ └── submodules.py │ │ │ ├── openpose/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── body.py │ │ │ │ ├── face.py │ │ │ │ ├── hand.py │ │ │ │ ├── model.py │ │ │ │ └── util.py │ │ │ ├── pidi.py │ │ │ ├── pidi_model.py │ │ │ ├── segment_anything/ │ │ │ │ ├── __init__.py │ │ │ │ ├── automatic_mask_generator.py │ │ │ │ ├── build_sam.py │ │ │ │ ├── modeling/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── common.py │ │ │ │ │ ├── image_encoder.py │ │ │ │ │ ├── mask_decoder.py │ │ │ │ │ ├── prompt_encoder.py │ │ │ │ │ ├── sam.py │ │ │ │ │ ├── tiny_vit_sam.py │ │ │ │ │ └── transformer.py │ │ │ │ ├── predictor.py │ │ │ │ └── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── amg.py │ │ │ │ ├── onnx.py │ │ │ │ └── transforms.py │ │ │ ├── shuffle.py │ │ │ └── zoe/ │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── zoedepth/ │ │ │ ├── __init__.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ ├── 
base_models/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── midas.py │ │ │ │ │ └── midas_repo/ │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── README.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── hubconf.py │ │ │ │ │ └── midas/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── backbones/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── beit.py │ │ │ │ │ │ ├── levit.py │ │ │ │ │ │ ├── next_vit.py │ │ │ │ │ │ ├── swin.py │ │ │ │ │ │ ├── swin2.py │ │ │ │ │ │ ├── swin_common.py │ │ │ │ │ │ ├── utils.py │ │ │ │ │ │ └── vit.py │ │ │ │ │ ├── base_model.py │ │ │ │ │ ├── blocks.py │ │ │ │ │ ├── dpt_depth.py │ │ │ │ │ ├── midas_net.py │ │ │ │ │ ├── midas_net_custom.py │ │ │ │ │ ├── model_loader.py │ │ │ │ │ └── transforms.py │ │ │ │ ├── builder.py │ │ │ │ ├── depth_model.py │ │ │ │ ├── layers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── attractor.py │ │ │ │ │ ├── dist_layers.py │ │ │ │ │ ├── localbins_layers.py │ │ │ │ │ └── patch_transformer.py │ │ │ │ ├── model_io.py │ │ │ │ ├── zoedepth/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── config_zoedepth.json │ │ │ │ │ ├── config_zoedepth_kitti.json │ │ │ │ │ └── zoedepth_v1.py │ │ │ │ └── zoedepth_nk/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config_zoedepth_nk.json │ │ │ │ └── zoedepth_nk_v1.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── arg_utils.py │ │ │ ├── config.py │ │ │ └── easydict/ │ │ │ └── __init__.py │ │ ├── processor.py │ │ ├── processors.py │ │ ├── run.py │ │ ├── test.py │ │ ├── tile.py │ │ ├── unit.py │ │ ├── units/ │ │ │ ├── controlnet.py │ │ │ ├── detect.py │ │ │ ├── lite.py │ │ │ ├── lite_model.py │ │ │ ├── reference.py │ │ │ ├── t2iadapter.py │ │ │ ├── xs.py │ │ │ ├── xs_model.py │ │ │ └── xs_pipe.py │ │ └── util.py │ ├── detailer.py │ ├── devices.py │ ├── devices_mac.py │ ├── dml/ │ │ ├── Generator.py │ │ ├── __init__.py │ │ ├── amp/ │ │ │ ├── __init__.py │ │ │ └── autocast_mode.py │ │ ├── backend.py │ │ ├── device.py │ │ ├── device_properties.py │ │ ├── hijack/ │ │ │ ├── __init__.py │ │ │ ├── realesrgan_model.py │ │ │ ├── tomesd.py │ │ │ ├── torch.py │ 
│ │ ├── transformers.py │ │ │ └── utils.py │ │ ├── memory.py │ │ ├── memory_amd/ │ │ │ ├── __init__.py │ │ │ └── driver/ │ │ │ ├── atiadlxx.py │ │ │ ├── atiadlxx_apis.py │ │ │ ├── atiadlxx_defines.py │ │ │ └── atiadlxx_structures.py │ │ ├── pdh/ │ │ │ ├── __init__.py │ │ │ ├── apis.py │ │ │ ├── defines.py │ │ │ ├── errors.py │ │ │ ├── msvcrt.py │ │ │ └── structures.py │ │ └── utils.py │ ├── errorlimiter.py │ ├── errors.py │ ├── extensions.py │ ├── extra_networks.py │ ├── extras.py │ ├── face/ │ │ ├── __init__.py │ │ ├── faceid.py │ │ ├── faceswap.py │ │ ├── insightface.py │ │ ├── instantid.py │ │ ├── instantid_model.py │ │ ├── photomaker.py │ │ ├── photomaker_model_v1.py │ │ ├── photomaker_model_v2.py │ │ ├── photomaker_pipeline.py │ │ ├── reswapper.py │ │ ├── reswapper_model.py │ │ └── reswapper_utils.py │ ├── face_restoration.py │ ├── facelib/ │ │ ├── __init__.py │ │ ├── detection/ │ │ │ ├── __init__.py │ │ │ ├── align_trans.py │ │ │ ├── matlab_cp2tform.py │ │ │ ├── retinaface/ │ │ │ │ ├── retinaface.py │ │ │ │ ├── retinaface_net.py │ │ │ │ └── retinaface_utils.py │ │ │ └── yolov5face/ │ │ │ ├── __init__.py │ │ │ ├── face_detector.py │ │ │ ├── models/ │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ ├── experimental.py │ │ │ │ ├── yolo.py │ │ │ │ ├── yolov5l.yaml │ │ │ │ └── yolov5n.yaml │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── autoanchor.py │ │ │ ├── datasets.py │ │ │ ├── extract_ckpt.py │ │ │ ├── general.py │ │ │ └── torch_utils.py │ │ ├── parsing/ │ │ │ ├── __init__.py │ │ │ ├── bisenet.py │ │ │ ├── parsenet.py │ │ │ └── resnet.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── face_restoration_helper.py │ │ ├── face_utils.py │ │ └── misc.py │ ├── files_cache.py │ ├── flash_attn_triton_amd/ │ │ ├── __init__.py │ │ ├── fwd_prefill.py │ │ ├── interface_fa.py │ │ └── utils.py │ ├── framepack/ │ │ ├── create-video.py │ │ ├── encode-video.py │ │ ├── framepack_api.py │ │ ├── framepack_hijack.py │ │ ├── framepack_install.py │ │ ├── framepack_load.py │ │ ├── 
framepack_ui.py │ │ ├── framepack_vae.py │ │ ├── framepack_worker.py │ │ ├── framepack_wrappers.py │ │ └── pipeline/ │ │ ├── bucket_tools.py │ │ ├── clip_vision.py │ │ ├── dit_common.py │ │ ├── hunyuan.py │ │ ├── hunyuan_video_packed.py │ │ ├── k_diffusion_hunyuan.py │ │ ├── thread_utils.py │ │ ├── uni_pc_fm.py │ │ ├── utils.py │ │ └── wrapper.py │ ├── generation_parameters_copypaste.py │ ├── ggml/ │ │ ├── __init__.py │ │ ├── gguf_tensor.py │ │ └── gguf_utils.py │ ├── gr_hijack.py │ ├── gr_tempdir.py │ ├── hashes.py │ ├── hidiffusion/ │ │ ├── __init__.py │ │ ├── hidiffusion.py │ │ ├── hidiffusion_controlnet.py │ │ └── utils.py │ ├── history.py │ ├── images.py │ ├── images_grid.py │ ├── images_namegen.py │ ├── images_resize.py │ ├── img2img.py │ ├── infotext.py │ ├── infotext_utils.py │ ├── intel/ │ │ ├── ipex/ │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── device_prop.py │ │ │ ├── diffusers.py │ │ │ └── hijacks.py │ │ └── openvino/ │ │ └── __init__.py │ ├── interrogate/ │ │ ├── deepbooru.py │ │ ├── deepbooru_model.py │ │ ├── deepseek.py │ │ ├── interrogate.py │ │ ├── joycaption.py │ │ ├── joytag.py │ │ ├── moondream3.py │ │ ├── openclip.py │ │ ├── tagger.py │ │ ├── vqa.py │ │ ├── vqa_detection.py │ │ └── waifudiffusion.py │ ├── ipadapter.py │ ├── json_helpers.py │ ├── lama.py │ ├── linfusion/ │ │ ├── __init__.py │ │ ├── attention.py │ │ └── linfusion.py │ ├── loader.py │ ├── localization.py │ ├── lora/ │ │ ├── extra_networks_lora.py │ │ ├── lora_apply.py │ │ ├── lora_common.py │ │ ├── lora_convert.py │ │ ├── lora_diffusers.py │ │ ├── lora_extract.py │ │ ├── lora_load.py │ │ ├── lora_nunchaku.py │ │ ├── lora_overrides.py │ │ ├── lora_timers.py │ │ ├── lyco_helpers.py │ │ ├── network.py │ │ ├── network_full.py │ │ ├── network_glora.py │ │ ├── network_hada.py │ │ ├── network_ia3.py │ │ ├── network_lokr.py │ │ ├── network_lora.py │ │ ├── network_norm.py │ │ ├── network_oft.py │ │ └── networks.py │ ├── ltx/ │ │ ├── ltx_process.py │ │ ├── ltx_ui.py │ │ └── 
ltx_util.py │ ├── masking.py │ ├── memmon.py │ ├── memstats.py │ ├── merging/ │ │ ├── convert_sdxl.py │ │ ├── merge.py │ │ ├── merge_PermSpec.py │ │ ├── merge_PermSpec_SDXL.py │ │ ├── merge_methods.py │ │ ├── merge_presets.py │ │ ├── merge_rebasin.py │ │ ├── merge_utils.py │ │ └── modules_sdxl.py │ ├── migrate.py │ ├── mit_nunchaku.py │ ├── model_quant.py │ ├── model_te.py │ ├── model_tools.py │ ├── modeldata.py │ ├── modelloader.py │ ├── models_hf.py │ ├── modelstats.py │ ├── modular.py │ ├── modular_guiders.py │ ├── olive_script.py │ ├── onnx_impl/ │ │ ├── __init__.py │ │ ├── execution_providers.py │ │ ├── pipelines/ │ │ │ ├── __init__.py │ │ │ ├── onnx_stable_diffusion_img2img_pipeline.py │ │ │ ├── onnx_stable_diffusion_inpaint_pipeline.py │ │ │ ├── onnx_stable_diffusion_pipeline.py │ │ │ ├── onnx_stable_diffusion_upscale_pipeline.py │ │ │ ├── onnx_stable_diffusion_xl_img2img_pipeline.py │ │ │ ├── onnx_stable_diffusion_xl_pipeline.py │ │ │ └── utils.py │ │ ├── ui.py │ │ └── utils.py │ ├── options.py │ ├── options_handler.py │ ├── pag/ │ │ ├── __init__.py │ │ ├── pipe_sd.py │ │ └── pipe_sdxl.py │ ├── para_attention.py │ ├── patches.py │ ├── paths.py │ ├── paths_internal.py │ ├── postprocess/ │ │ ├── aurasr_arch.py │ │ ├── aurasr_model.py │ │ ├── codeformer_arch.py │ │ ├── codeformer_model.py │ │ ├── dcc.py │ │ ├── esrgan_model.py │ │ ├── esrgan_model_arch.py │ │ ├── gfpgan_model.py │ │ ├── hqx.py │ │ ├── icbi.py │ │ ├── pixelart.py │ │ ├── realesrgan_model.py │ │ ├── realesrgan_model_arch.py │ │ ├── restorer.py │ │ ├── scunet_model.py │ │ ├── scunet_model_arch.py │ │ ├── sdupscaler_model.py │ │ ├── seedvr_model.py │ │ ├── swinir_model.py │ │ ├── swinir_model_arch.py │ │ ├── swinir_model_arch_v2.py │ │ ├── vqgan_arch.py │ │ └── yolo.py │ ├── postprocessing.py │ ├── processing.py │ ├── processing_args.py │ ├── processing_callbacks.py │ ├── processing_class.py │ ├── processing_correction.py │ ├── processing_diffusers.py │ ├── processing_helpers.py │ ├── 
processing_info.py │ ├── processing_prompt.py │ ├── processing_vae.py │ ├── progress.py │ ├── prompt_parser.py │ ├── prompt_parser_diffusers.py │ ├── prompt_parser_xhinker.py │ ├── ras/ │ │ ├── __init__.py │ │ ├── ras_attention.py │ │ ├── ras_forward.py │ │ ├── ras_manager.py │ │ └── ras_scheduler.py │ ├── res4lyf/ │ │ ├── __init__.py │ │ ├── abnorsett_scheduler.py │ │ ├── bong_tangent_scheduler.py │ │ ├── common_sigma_scheduler.py │ │ ├── deis_scheduler_alt.py │ │ ├── etdrk_scheduler.py │ │ ├── gauss_legendre_scheduler.py │ │ ├── langevin_dynamics_scheduler.py │ │ ├── lawson_scheduler.py │ │ ├── linear_rk_scheduler.py │ │ ├── lobatto_scheduler.py │ │ ├── pec_scheduler.py │ │ ├── phi_functions.py │ │ ├── radau_iia_scheduler.py │ │ ├── res_multistep_scheduler.py │ │ ├── res_multistep_sde_scheduler.py │ │ ├── res_singlestep_scheduler.py │ │ ├── res_singlestep_sde_scheduler.py │ │ ├── res_unified_scheduler.py │ │ ├── riemannian_flow_scheduler.py │ │ ├── rungekutta_44s_scheduler.py │ │ ├── rungekutta_57s_scheduler.py │ │ ├── rungekutta_67s_scheduler.py │ │ ├── scheduler_utils.py │ │ ├── simple_exponential_scheduler.py │ │ ├── specialized_rk_scheduler.py │ │ └── variants.py │ ├── rife/ │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── model_ifnet.py │ │ ├── model_rife.py │ │ ├── refine.py │ │ ├── ssim.py │ │ └── warplayer.py │ ├── rocm.py │ ├── rocm_triton_windows.py │ ├── safe.py │ ├── schedulers/ │ │ ├── perflow/ │ │ │ ├── __init__.py │ │ │ ├── pfode_solver.py │ │ │ ├── scheduler_perflow.py │ │ │ └── utils_perflow.py │ │ ├── scheduler_bdia.py │ │ ├── scheduler_dc.py │ │ ├── scheduler_dpm_flowmatch.py │ │ ├── scheduler_flashflow.py │ │ ├── scheduler_tcd.py │ │ ├── scheduler_tdd.py │ │ ├── scheduler_ufogen.py │ │ ├── scheduler_unipc_flowmatch.py │ │ └── scheduler_vdm.py │ ├── script_callbacks.py │ ├── script_loading.py │ ├── scripts.py │ ├── scripts_auto_postprocessing.py │ ├── scripts_manager.py │ ├── scripts_postprocessing.py │ ├── sd_checkpoint.py │ ├── sd_detect.py │ ├── 
sd_hijack.py │ ├── sd_hijack_accelerate.py │ ├── sd_hijack_dynamic_atten.py │ ├── sd_hijack_freeu.py │ ├── sd_hijack_hypertile.py │ ├── sd_hijack_safetensors.py │ ├── sd_hijack_te.py │ ├── sd_hijack_utils.py │ ├── sd_hijack_vae.py │ ├── sd_models.py │ ├── sd_models_compile.py │ ├── sd_models_utils.py │ ├── sd_modules.py │ ├── sd_offload.py │ ├── sd_samplers.py │ ├── sd_samplers_common.py │ ├── sd_samplers_diffusers.py │ ├── sd_te_remote.py │ ├── sd_unet.py │ ├── sd_vae.py │ ├── sdnq/ │ │ ├── __init__.py │ │ ├── common.py │ │ ├── dequantizer.py │ │ ├── file_loader.py │ │ ├── forward.py │ │ ├── layers/ │ │ │ ├── __init__.py │ │ │ ├── conv/ │ │ │ │ ├── conv_fp16.py │ │ │ │ ├── conv_fp8.py │ │ │ │ ├── conv_fp8_tensorwise.py │ │ │ │ ├── conv_int8.py │ │ │ │ └── forward.py │ │ │ └── linear/ │ │ │ ├── forward.py │ │ │ ├── linear_fp16.py │ │ │ ├── linear_fp8.py │ │ │ ├── linear_fp8_tensorwise.py │ │ │ └── linear_int8.py │ │ ├── loader.py │ │ ├── packed_float.py │ │ ├── packed_int.py │ │ ├── quantizer.py │ │ └── triton_mm.py │ ├── seedvr/ │ │ ├── __init__.py │ │ ├── config_3b.yaml │ │ ├── config_7b.yaml │ │ ├── rotary_embedding.py │ │ ├── src/ │ │ │ ├── __init__.py │ │ │ ├── common/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cache.py │ │ │ │ ├── config.py │ │ │ │ ├── decorators.py │ │ │ │ ├── diffusion/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── config.py │ │ │ │ │ ├── samplers/ │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ └── euler.py │ │ │ │ │ ├── schedules/ │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ └── lerp.py │ │ │ │ │ ├── timesteps/ │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ └── sampling/ │ │ │ │ │ │ └── trailing.py │ │ │ │ │ ├── types.py │ │ │ │ │ └── utils.py │ │ │ │ ├── distributed/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── advanced.py │ │ │ │ │ ├── basic.py │ │ │ │ │ ├── meta_init_utils.py │ │ │ │ │ └── ops.py │ │ │ │ ├── half_precision_fixes.py │ │ │ │ ├── logger.py │ │ │ │ ├── partition.py │ │ │ │ └── seed.py │ │ │ ├── core/ │ │ │ │ ├── __init__.py │ │ │ │ ├── generation.py │ │ │ │ ├── infer.py │ 
│ │ │ └── model_manager.py │ │ │ ├── data/ │ │ │ │ └── image/ │ │ │ │ └── transforms/ │ │ │ │ ├── area_resize.py │ │ │ │ ├── divisible_crop.py │ │ │ │ ├── na_resize.py │ │ │ │ └── side_resize.py │ │ │ ├── models/ │ │ │ │ ├── dit/ │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── blocks/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── mmdit_window_block.py │ │ │ │ │ ├── embedding.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── mm.py │ │ │ │ │ ├── modulation.py │ │ │ │ │ ├── na.py │ │ │ │ │ ├── nablocks/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── mmsr_block.py │ │ │ │ │ ├── nadit.py │ │ │ │ │ ├── normalization.py │ │ │ │ │ ├── patch.py │ │ │ │ │ ├── rope.py │ │ │ │ │ └── window.py │ │ │ │ ├── dit_v2/ │ │ │ │ │ ├── attention.py │ │ │ │ │ ├── embedding.py │ │ │ │ │ ├── mlp.py │ │ │ │ │ ├── mm.py │ │ │ │ │ ├── modulation.py │ │ │ │ │ ├── na.py │ │ │ │ │ ├── nablocks/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── attention/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── mmattn.py │ │ │ │ │ │ └── mmsr_block.py │ │ │ │ │ ├── nadit.py │ │ │ │ │ ├── normalization.py │ │ │ │ │ ├── patch/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── patch_v1.py │ │ │ │ │ ├── rope.py │ │ │ │ │ └── window.py │ │ │ │ └── video_vae_v3/ │ │ │ │ ├── modules/ │ │ │ │ │ ├── attn_video_vae.py │ │ │ │ │ ├── causal_inflation_lib.py │ │ │ │ │ ├── context_parallel_lib.py │ │ │ │ │ ├── global_config.py │ │ │ │ │ ├── inflated_layers.py │ │ │ │ │ ├── inflated_lib.py │ │ │ │ │ ├── types.py │ │ │ │ │ └── video_vae.py.old │ │ │ │ └── s8_c16_t4_inflation_sd3.yaml │ │ │ ├── optimization/ │ │ │ │ ├── __init__.py │ │ │ │ ├── memory_manager.py │ │ │ │ └── performance.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ └── color_fix.py │ │ └── test.py │ ├── server.py │ ├── shared.py │ ├── shared_defaults.py │ ├── shared_helpers.py │ ├── shared_items.py │ ├── shared_legacy.py │ ├── shared_state.py │ ├── styles.py │ ├── sub_quadratic_attention.py │ ├── taesd/ │ │ ├── hybrid_small.py │ │ ├── taehv.py │ │ ├── taem1.py │ │ └── taesd.py │ ├── 
teacache/ │ │ ├── __init__.py │ │ ├── teacache_chroma.py │ │ ├── teacache_cogvideox.py │ │ ├── teacache_flux.py │ │ ├── teacache_hidream.py │ │ ├── teacache_ltx.py │ │ ├── teacache_lumina2.py │ │ └── teacache_mochi.py │ ├── textual_inversion.py │ ├── theme.py │ ├── timer.py │ ├── todo/ │ │ ├── __init__.py │ │ ├── todo_merge.py │ │ └── todo_utils.py │ ├── token_merge.py │ ├── transformer_cache.py │ ├── txt2img.py │ ├── ui.py │ ├── ui_caption.py │ ├── ui_common.py │ ├── ui_components.py │ ├── ui_control.py │ ├── ui_control_elements.py │ ├── ui_control_helpers.py │ ├── ui_docs.py │ ├── ui_extensions.py │ ├── ui_extra_networks.py │ ├── ui_extra_networks_checkpoints.py │ ├── ui_extra_networks_history.py │ ├── ui_extra_networks_lora.py │ ├── ui_extra_networks_styles.py │ ├── ui_extra_networks_textual_inversion.py │ ├── ui_extra_networks_vae.py │ ├── ui_extra_networks_wildcards.py │ ├── ui_gallery.py │ ├── ui_guidance.py │ ├── ui_history.py │ ├── ui_img2img.py │ ├── ui_javascript.py │ ├── ui_loadsave.py │ ├── ui_models.py │ ├── ui_models_load.py │ ├── ui_postprocessing.py │ ├── ui_prompt_styles.py │ ├── ui_sections.py │ ├── ui_settings.py │ ├── ui_symbols.py │ ├── ui_txt2img.py │ ├── ui_video.py │ ├── ui_video_vlm.py │ ├── update.py │ ├── upscaler.py │ ├── upscaler_algo.py │ ├── upscaler_simple.py │ ├── upscaler_spandrel.py │ ├── upscaler_vae.py │ ├── vae/ │ │ ├── sd_vae_approx.py │ │ ├── sd_vae_fal.py │ │ ├── sd_vae_natten.py │ │ ├── sd_vae_ostris.py │ │ ├── sd_vae_remote.py │ │ ├── sd_vae_repa.py │ │ ├── sd_vae_stablecascade.py │ │ └── sd_vae_taesd.py │ ├── video.py │ ├── video_models/ │ │ ├── google_veo.py │ │ ├── models_def.py │ │ ├── video_cache.py │ │ ├── video_load.py │ │ ├── video_overrides.py │ │ ├── video_prompt.py │ │ ├── video_run.py │ │ ├── video_save.py │ │ ├── video_ui.py │ │ ├── video_utils.py │ │ └── video_vae.py │ ├── zluda.py │ └── zluda_installer.py ├── motd ├── package.json ├── pipelines/ │ ├── bria/ │ │ ├── __init__.py │ │ ├── bria_pipeline.py │ │ 
├── bria_utils.py │ │ ├── transformer_block.py │ │ └── transformer_bria.py │ ├── f_lite/ │ │ ├── __init__.py │ │ ├── f_lite.model.py │ │ ├── model.py │ │ └── pipeline.py │ ├── flex2/ │ │ └── __init__.py │ ├── flux/ │ │ ├── flux_bnb.py │ │ ├── flux_legacy_loader.py │ │ ├── flux_lora.py │ │ ├── flux_nf4.py │ │ ├── flux_nunchaku.py │ │ └── flux_quanto.py │ ├── generic.py │ ├── hdm/ │ │ ├── __init__.py │ │ ├── hdm/ │ │ │ ├── __init__.py │ │ │ ├── data/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── kohya.py │ │ │ ├── loader.py │ │ │ ├── modules/ │ │ │ │ ├── base.py │ │ │ │ ├── rope.py │ │ │ │ ├── text_encoders.py │ │ │ │ ├── unet_patch.py │ │ │ │ └── xut.py │ │ │ ├── pipeline.py │ │ │ ├── trainer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── callbacks.py │ │ │ │ ├── diffusion.py │ │ │ │ └── trainer.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ └── config.py │ │ └── xut/ │ │ ├── __init__.py │ │ ├── env.py │ │ ├── modules/ │ │ │ ├── __init__.py │ │ │ ├── adaln.py │ │ │ ├── attention.py │ │ │ ├── axial_rope.py │ │ │ ├── layers.py │ │ │ ├── norm.py │ │ │ ├── patch.py │ │ │ ├── time_emb.py │ │ │ └── transformer.py │ │ ├── utils/ │ │ │ └── __init__.py │ │ └── xut.py │ ├── hidream/ │ │ └── pipeline_hidream_image_editing.py │ ├── meissonic/ │ │ ├── __init__.py │ │ ├── pipeline.py │ │ ├── pipeline_img2img.py │ │ ├── pipeline_inpaint.py │ │ ├── scheduler.py │ │ ├── test.py │ │ └── transformer.py │ ├── model_anima.py │ ├── model_auraflow.py │ ├── model_bria.py │ ├── model_chroma.py │ ├── model_chrono.py │ ├── model_cogview.py │ ├── model_cosmos.py │ ├── model_flex.py │ ├── model_flite.py │ ├── model_flux.py │ ├── model_flux2.py │ ├── model_flux2_klein.py │ ├── model_glm.py │ ├── model_google.py │ ├── model_hdm.py │ ├── model_hidream.py │ ├── model_hunyuandit.py │ ├── model_hyimage.py │ ├── model_kandinsky.py │ ├── model_kolors.py │ ├── model_longcat.py │ ├── model_lumina.py │ ├── model_meissonic.py │ ├── model_nextstep.py │ ├── model_omnigen.py │ ├── model_ovis.py │ ├── 
model_pixart.py │ ├── model_prx.py │ ├── model_qwen.py │ ├── model_sana.py │ ├── model_sd3.py │ ├── model_stablecascade.py │ ├── model_wanai.py │ ├── model_xomni.py │ ├── model_z_image.py │ ├── omnigen2/ │ │ ├── __init__.py │ │ ├── image_processor.py │ │ ├── models/ │ │ │ ├── attention_processor.py │ │ │ ├── embeddings.py │ │ │ └── transformers/ │ │ │ ├── __init__.py │ │ │ ├── block_lumina2.py │ │ │ ├── repo.py │ │ │ └── transformer_omnigen2.py │ │ └── pipeline_omnigen2.py │ ├── qwen/ │ │ ├── __init__.py │ │ ├── qwen_nunchaku.py │ │ └── qwen_pruning.py │ ├── segmoe/ │ │ └── segmoe_model.py │ ├── wan/ │ │ └── wan_image.py │ └── xomni/ │ ├── __init__.py │ ├── configuration_xomni.py │ ├── modeling_siglip_flux.py │ ├── modeling_siglip_tokenizer.py │ ├── modeling_vit.py │ └── modeling_xomni.py ├── requirements.txt ├── scripts/ │ ├── animatediff.py │ ├── apg.py │ ├── automatic_color_inpaint.py │ ├── blipdiffusion.py │ ├── consistory/ │ │ ├── __init__.py │ │ ├── attention_processor.py │ │ ├── consistory_pipeline.py │ │ ├── consistory_run.py │ │ ├── consistory_unet_sdxl.py │ │ ├── consistory_utils.py │ │ └── utils/ │ │ ├── general_utils.py │ │ └── ptp_utils.py │ ├── consistory_ext.py │ ├── ctrlx/ │ │ ├── __init__.py │ │ ├── features.py │ │ ├── media.py │ │ ├── sdxl.py │ │ └── utils.py │ ├── ctrlx_ext.py │ ├── custom_code.py │ ├── daam/ │ │ ├── __init__.py │ │ ├── evaluate.py │ │ ├── experiment.py │ │ ├── heatmap.py │ │ ├── hook.py │ │ ├── trace.py │ │ └── utils.py │ ├── daam_ext.py │ ├── demofusion.py │ ├── differential_diffusion.py │ ├── example.py │ ├── flux_enhance.py │ ├── flux_tools.py │ ├── freescale/ │ │ ├── __init__.py │ │ ├── free_lunch_utils.py │ │ ├── freescale_pipeline.py │ │ ├── freescale_pipeline_img2img.py │ │ └── scale_attention.py │ ├── freescale_ext.py │ ├── hdr.py │ ├── image2video.py │ ├── infiniteyou/ │ │ ├── __init__.py │ │ ├── pipeline_flux_infusenet.py │ │ ├── pipeline_infu_flux.py │ │ └── resampler.py │ ├── infiniteyou_ext.py │ ├── init_latents.py 
│ ├── instantir/ │ │ ├── __init__.py │ │ ├── aggregator.py │ │ ├── ip_adapter/ │ │ │ ├── __init__.py │ │ │ ├── attention_processor.py │ │ │ ├── ip_adapter.py │ │ │ ├── resampler.py │ │ │ └── utils.py │ │ ├── lcm_single_step_scheduler.py │ │ └── sdxl_instantir.py │ ├── instantir_ext.py │ ├── ipadapter.py │ ├── ipinstruct.py │ ├── kohya_hires_fix.py │ ├── layerdiffuse/ │ │ ├── __init__.py │ │ ├── layerdiffuse_loader.py │ │ └── layerdiffuse_model.py │ ├── layerdiffuse_ext.py │ ├── lbm/ │ │ ├── __init__.py │ │ ├── base/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── model_config.py │ │ ├── config.py │ │ ├── embedders/ │ │ │ ├── __init__.py │ │ │ ├── base/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base_conditioner.py │ │ │ │ └── base_conditioner_config.py │ │ │ ├── conditioners_wrapper.py │ │ │ └── latents_concat/ │ │ │ ├── __init__.py │ │ │ ├── latents_concat_embedder_config.py │ │ │ └── latents_concat_embedder_model.py │ │ ├── extract.py │ │ ├── inference.py │ │ ├── lbm/ │ │ │ ├── __init__.py │ │ │ ├── lbm_config.py │ │ │ └── lbm_model.py │ │ ├── tiler.py │ │ ├── unets/ │ │ │ ├── __init__.py │ │ │ └── unet.py │ │ ├── utils.py │ │ └── vae/ │ │ ├── __init__.py │ │ ├── autoencoderKL.py │ │ └── autoencoderKL_config.py │ ├── lbm_ext.py │ ├── ledits.py │ ├── loopback.py │ ├── lut.py │ ├── mixture_of_diffusers.py │ ├── mixture_tiling.py │ ├── mod/ │ │ └── __init__.py │ ├── mulan.py │ ├── nudenet/ │ │ ├── bannedwords.py │ │ ├── imageguard.py │ │ ├── langdetect.py │ │ └── nudenet.py │ ├── nudenet_ext.py │ ├── outpainting_mk_2.py │ ├── pixelsmith/ │ │ ├── __init__.py │ │ ├── autoencoder_kl.py │ │ ├── pixelsmith_pipeline.py │ │ └── vae.py │ ├── pixelsmith_ext.py │ ├── poor_mans_outpainting.py │ ├── postprocessing_codeformer.py │ ├── postprocessing_gfpgan.py │ ├── postprocessing_pixelart.py │ ├── postprocessing_upscale.py │ ├── postprocessing_video.py │ ├── prompt_enhance.py │ ├── prompt_matrix.py │ ├── prompts_from_file.py │ ├── pulid/ │ │ ├── __init__.py │ │ ├── 
attention_processor.py │ │ ├── encoders_transformer.py │ │ ├── eva_clip/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── eva_vit_model.py │ │ │ ├── factory.py │ │ │ ├── hf_configs.py │ │ │ ├── hf_model.py │ │ │ ├── loss.py │ │ │ ├── model.py │ │ │ ├── model_configs/ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ └── EVA02-CLIP-bigE-14.json │ │ │ ├── modified_resnet.py │ │ │ ├── openai.py │ │ │ ├── pretrained.py │ │ │ ├── rope.py │ │ │ ├── timm_model.py │ │ │ ├── tokenizer.py │ │ │ ├── transform.py │ │ │ ├── transformer.py │ │ │ └── utils.py │ │ ├── pulid_flux.py │ │ ├── pulid_sampling.py │ │ ├── pulid_sdxl.py │ │ └── pulid_utils.py │ ├── pulid_ext.py │ ├── regional_prompting.py │ ├── resadapter.py │ ├── sd_upscale.py │ ├── skip_layer_guidance.py │ ├── softfill.py │ ├── stablevideodiffusion.py │ ├── style_aligned/ │ │ ├── inversion.py │ │ └── sa_handler.py │ ├── style_aligned_ext.py │ ├── t_gate.py │ ├── text2video.py │ ├── tiling.py │ ├── xadapter/ │ │ ├── adapter.py │ │ ├── pipeline_sd_xl_adapter.py │ │ ├── pipeline_sd_xl_adapter_controlnet.py │ │ ├── pipeline_sd_xl_adapter_controlnet_img2img.py │ │ ├── unet_adapter.py │ │ ├── utils.py │ │ └── xadapter_hijacks.py │ ├── xadapter_ext.py │ ├── xyz/ │ │ ├── xyz_grid_classes.py │ │ ├── xyz_grid_draw.py │ │ └── xyz_grid_shared.py │ ├── xyz_grid.py │ └── xyz_grid_on.py ├── webui.bat ├── webui.ps1 ├── webui.py └── webui.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ # defaults .history .vscode/ /__pycache__ /.ruff_cache /cache /cache.json /config.json /extensions/* /html/extensions.json /html/themes.json /metadata.json 
/node_modules /outputs/* /package-lock.json /params.txt /pnpm-lock.yaml /styles.csv /tmp /ui-config.json /user.css /venv /webui-user.bat /webui-user.sh /*.log.* /*.log ================================================ FILE: .gitignore ================================================ # defaults venv/ __pycache__ .ruff_cache /*.json /*.yaml /params.txt /styles.csv /user.css /webui-user.bat /webui-user.sh /data/metadata.json /data/extensions.json /data/cache.json /data/themes.json config_states node_modules pnpm-lock.yaml package-lock.json .history cache **/.DS_Store tunableop_results*.csv # all models and temp files *.log *.log.* *.bak *.ckpt *.safetensors *.pth *.pt *.bin *.optim *.lock *.zip *.rar *.7z *.pyc /*.bat /*.sh /*.txt /*.mp3 /*.lnk /*.swp !webui.bat !webui.sh !package.json !requirements.txt # all dynamic stuff /extensions/**/* /outputs/**/* /embeddings/**/* /models/**/* /interrogate/**/* /train/log/**/* /textual_inversion/**/* /detected_maps/**/* /tmp /log /cert .vscode/ .idea/ /localizations .*/ # force included !/data !/models/VAE-approx !/models/VAE-approx/model.pt !/models/Reference !/models/Reference/**/* ================================================ FILE: .gitmodules ================================================ [submodule "wiki"] path = wiki url = https://github.com/vladmandic/sdnext.wiki ignore = dirty [submodule "extensions-builtin/sd-extension-system-info"] path = extensions-builtin/sd-extension-system-info url = https://github.com/vladmandic/sd-extension-system-info ignore = dirty [submodule "extensions-builtin/sd-extension-chainner"] path = extensions-builtin/sd-extension-chainner url = https://github.com/vladmandic/sd-extension-chainner ignore = dirty [submodule "extensions-builtin/stable-diffusion-webui-rembg"] path = extensions-builtin/stable-diffusion-webui-rembg url = https://github.com/vladmandic/sd-extension-rembg ignore = dirty [submodule "extensions-builtin/sdnext-modernui"] path = extensions-builtin/sdnext-modernui url = 
https://github.com/BinaryQuantumSoul/sdnext-modernui [submodule "extensions-builtin/sdnext-kanvas"] path = extensions-builtin/sdnext-kanvas url = https://github.com/vladmandic/sdnext-kanvas ================================================ FILE: .pylintrc ================================================ [MAIN] analyse-fallback-blocks=no clear-cache-post-run=no extension-pkg-allow-list= prefer-stubs=yes extension-pkg-whitelist= fail-on= fail-under=10 ignore=CVS ignore-paths=/usr/lib/.*$, venv, .git, .ruff_cache, .vscode, modules/apg, modules/cfgzero, modules/control/proc, modules/control/units, modules/dml, modules/facelib, modules/flash_attn_triton_amd, modules/ggml, modules/hidiffusion, modules/hijack/ddpm_edit.py, modules/intel, modules/intel/ipex, modules/framepack/pipeline, modules/onnx_impl, modules/pag, modules/postprocess/aurasr_arch.py, modules/prompt_parser_xhinker.py, modules/ras, modules/seedvr, modules/rife, modules/schedulers, modules/taesd, modules/teacache, modules/todo, modules/res4lyf, pipelines/bria, pipelines/flex2, pipelines/f_lite, pipelines/hidream, pipelines/hdm, pipelines/meissonic, pipelines/omnigen2, pipelines/segmoe, pipelines/xomni, pipelines/chrono, scripts/consistory, scripts/ctrlx, scripts/daam, scripts/demofusion, scripts/freescale, scripts/infiniteyou, scripts/instantir, scripts/lbm, scripts/layerdiffuse, scripts/mod, scripts/pixelsmith, scripts/differential_diffusion.py, scripts/pulid, scripts/xadapter, repositories, extensions-builtin/sd-extension-chainner/nodes, extensions-builtin/sd-webui-agent-scheduler, extensions-builtin/sdnext-modernui/node_modules, extensions-builtin/sdnext-kanvas/node_modules, ignore-patterns=.*test*.py$, .*_model.py$, .*_arch.py$, .*_model_arch.py*, .*_model_arch_v2.py$, ignored-modules= jobs=8 limit-inference-results=100 load-plugins= persistent=no py-version=3.10 recursive=no source-roots= unsafe-load-any-extension=no [BASIC] argument-naming-style=snake_case attr-naming-style=snake_case bad-names=foo, 
bar, baz, toto, tutu, tata bad-names-rgxs= class-attribute-naming-style=any class-const-naming-style=UPPER_CASE class-naming-style=PascalCase const-naming-style=snake_case docstring-min-length=-1 function-naming-style=snake_case good-names=i,j,k,e,ex,ok,p,x,y,id good-names-rgxs= include-naming-hint=no inlinevar-naming-style=any method-naming-style=snake_case module-naming-style=snake_case name-group= no-docstring-rgx=^_ property-classes=abc.abstractproperty variable-naming-style=snake_case [CLASSES] check-protected-access-in-special-methods=no defining-attr-methods=__init__, __new__, exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit valid-classmethod-first-arg=cls valid-metaclass-classmethod-first-arg=mcs [DESIGN] exclude-too-few-public-methods= ignored-parents= max-args=199 max-attributes=99 max-bool-expr=99 max-branches=199 max-locals=99 max-parents=99 max-public-methods=99 max-returns=99 max-statements=199 min-public-methods=1 [EXCEPTIONS] overgeneral-exceptions=builtins.BaseException,builtins.Exception [FORMAT] expected-line-ending-format= ignore-long-lines=^\s*(# )??$ indent-after-paren=4 indent-string=' ' max-line-length=200 max-module-lines=9999 single-line-class-stmt=no single-line-if-stmt=no [IMPORTS] allow-any-import-level= allow-reexport-from-package=no allow-wildcard-with-all=no deprecated-modules= ext-import-graph= import-graph= int-import-graph= known-standard-library= known-third-party=enchant preferred-modules= [LOGGING] logging-format-style=new logging-modules=logging [MESSAGES CONTROL] confidence=HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED # disable=C,R,W disable=abstract-method, bad-inline-option, bare-except, broad-exception-caught, chained-comparison, consider-iterating-dictionary, consider-merging-isinstance, consider-using-dict-items, consider-using-enumerate, consider-using-from-import, consider-using-generator, consider-using-get, consider-using-in, consider-using-max-builtin, consider-using-min-builtin, 
consider-using-sys-exit, cyclic-import, dangerous-default-value, deprecated-pragma, duplicate-code, file-ignored, import-error, import-outside-toplevel, invalid-name, line-too-long, locally-disabled, logging-fstring-interpolation, missing-class-docstring, missing-function-docstring, missing-module-docstring, no-else-raise, no-else-return, not-callable, pointless-string-statement, raw-checker-failed, simplifiable-if-expression, suppressed-message, too-few-public-methods, too-many-instance-attributes, too-many-locals, too-many-nested-blocks, too-many-positional-arguments, too-many-statements, unidiomatic-typecheck, unknown-option-value, unnecessary-dict-index-lookup, unnecessary-dunder-call, unnecessary-lambda-assigment, unnecessary-lambda, unused-wildcard-import, unpacking-non-sequence, unsubscriptable-object, useless-return, use-dict-literal, use-symbolic-message-instead, useless-suppression, wrong-import-position, enable=c-extension-no-member [METHOD_ARGS] timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request [MISCELLANEOUS] notes=FIXME, XXX, TODO notes-rgx= [REFACTORING] max-nested-blocks=5 never-returning-functions=sys.exit,argparse.parse_error [REPORTS] evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) msg-template= reports=no score=no [SIMILARITIES] ignore-comments=yes ignore-docstrings=yes ignore-imports=yes ignore-signatures=yes min-similarity-lines=4 [SPELLING] max-spelling-suggestions=4 spelling-dict= spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: spelling-ignore-words= spelling-private-dict-file= spelling-store-unknown-words=no [STRING] check-quote-consistency=no check-str-concat-over-line-jumps=no [TYPECHECK] contextmanager-decorators=contextlib.contextmanager generated-members=numpy.*,logging.*,torch.*,cv2.* ignore-none=yes 
ignore-on-opaque-inference=yes ignored-checks-for-mixins=no-member, not-async-context-manager, not-context-manager, attribute-defined-outside-init ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace missing-member-hint=yes missing-member-hint-distance=1 missing-member-max-choices=1 mixin-class-rgx=.*[Mm]ixin signature-mutators= [VARIABLES] additional-builtins= allow-global-unused-variables=yes allowed-redefined-builtins= callbacks=cb_, dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ ignored-argument-names=_.*|^ignored_|^unused_ init-import=no redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io ================================================ FILE: .ruff.toml ================================================ line-length = 250 indent-width = 4 target-version = "py310" exclude = [ "venv", ".git", ".ruff_cache", ".vscode", "modules/cfgzero", "modules/facelib", "modules/flash_attn_triton_amd", "modules/hidiffusion", "modules/intel/ipex", "modules/pag", "modules/schedulers", "modules/teacache", "modules/seedvr", "modules/control/proc", "modules/control/units", "modules/control/units/xs_pipe.py", "modules/postprocess/aurasr_arch.py", "pipelines/meissonic", "pipelines/omnigen2", "pipelines/hdm", "pipelines/segmoe", "pipelines/xomni", "pipelines/chrono", "scripts/lbm", "scripts/daam", "scripts/xadapter", "scripts/pulid", "scripts/instantir", "scripts/freescale", "scripts/consistory", "repositories", "extensions-builtin/Lora", "extensions-builtin/sd-extension-chainner/nodes", "extensions-builtin/sd-webui-agent-scheduler", "extensions-builtin/sdnext-modernui/node_modules", ] [lint] select = [ "F", "E", "W", "C", "B", "I", "YTT", "ASYNC", "RUF", "AIR", "NPY", "C4", "T10", "EXE", "ISC", "ICN", "RSE", "TCH", "TID", "INT", "PLE", ] ignore = [ "B006", # Do not use mutable data structures for argument defaults "B008", # Do not perform function call in argument defaults "B905", # Strict zip() 
usage "C420", # Unnecessary dict comprehension for iterable; use `dict.fromkeys` instead "C408", # Unnecessary `dict` call "I001", # Import block is un-sorted or un-formatted "E402", # Module level import not at top of file "E501", # Line too long "E721", # Do not compare types, use `isinstance()` "E731", # Do not assign a `lambda` expression, use a `def` "E741", # Ambiguous variable name "F401", # Imported by unused "EXE001", # file with shebang is not marked executable "NPY002", # replace legacy random "RUF005", # Consider iterable unpacking "RUF008", # Do not use mutable default values for dataclass "RUF010", # Use explicit conversion flag "RUF012", # Mutable class attributes "RUF013", # PEP 484 prohibits implicit `Optional` "RUF015", # Prefer `next(...)` over single element slice "RUF046", # Value being cast to `int` is already an integer "RUF059", # Unpacked variables are not used "RUF051", # Prefer pop over del ] fixable = ["ALL"] unfixable = [] dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [format] quote-style = "double" indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" docstring-code-format = false [lint.mccabe] max-complexity = 150 ================================================ FILE: CHANGELOG.md ================================================ # Change Log for SD.Next ## Update for 2026-02-07 - **Upscalers** - add support for [spandrel](https://github.com/chaiNNer-org/spandrel) upscaling engine with suport for new upscaling model families - add two new ai upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2* - add two new interpolation methods: *HQX* and *ICB* - **Features** - pipelines: add **ZImageInpaint**, thanks @CalamitousFelicitousness - add `--remote` command line flag that reduces client/server chatter and improves link stability for long-running generates, useful when running on remote servers - **UI** - ui: **themes** add *CTD-NT64Light* and *CTD-NT64Dark*, thanks @resonantsky - ui: 
**gallery** add option to auto-refresh gallery, thanks @awsr - **Internal** - refactor: reorganize `cli` scripts - **Fixes** - fix: add metadata restore to always-on scripts - fix: improve wildcard weights parsing, thanks @Tillerz - fix: ui gallery cace recursive cleanup, thanks @awsr - fix: `anima` model detection - fix: lora unwanted unload - fix: improve preview error handler ## Update for 2026-02-04 ### Highlights for 2026-02-04 Refresh release two weeks after prior release, yet we still somehow managed to pack in *~150 commits*! Highlights would be two new models: **Z-Image-Base** and **Anima**, *captioning* support for **tagger** models and a massive addition of new **schedulers** Also here are updates to `torch` and additional GPU archs support for `ROCm` backends, plus a lot of internal improvements and fixes. [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2026-02-04 - **Models** - [Tongyi-MAI Z-Image Base](https://tongyi-mai.github.io/Z-Image-blog/) yup, its finally here, the full base model of **Z-Image** - [CircleStone Anima](https://huggingface.co/circlestone-labs/Anima) 2B anime optimized model based on a modified Cosmos-Predict, using Qwen3-0.6B as a text encoder - **Features** - **caption** tab support for Booru tagger models, thanks @CalamitousFelicitousness - add SmilingWolf WD14/WaifuDiffusion tagger models, thanks @CalamitousFelicitousness - support comments in wildcard files, using `#` - support aliases in metadata skip params, thanks @CalamitousFelicitousness - ui gallery improve cache cleanup and add manual option, thanks @awsr - selectable options to add system info to 
metadata, thanks @Athari see *settings -> image metadata* - **Schedulers** - schedulers documentation has new home: - add 13(!) new scheduler families not a port, but more of inspired-by [res4lyf](https://github.com/ClownsharkBatwing/RES4LYF) library all schedulers should be compatible with both `epsilon` and `flow` prediction style! *note*: each family may have multiple actual schedulers, so the list total is 56(!) new schedulers - core family: *RES* - exponential: *DEIS, ETD, Lawson, ABNorsett* - integrators: *Runge-Kutta, Linear-RK, Specialized-RK, Lobatto, Radau-IIA, Gauss-Legendre* - flow: *PEC, Riemannian, Euclidean, Hyperbolic, Lorentzian, Langevin-Dynamics* - add 3 additional schedulers: *CogXDDIM, DDIMParallel, DDPMParallel* not originally intended to be a general purpose schedulers, but they work quite nicely and produce good results - image metadata: always log scheduler class used - **API** - add `/sdapi/v1/xyz-grid` to enumerate xyz-grid axis options and their choices see `/cli/api-xyzenum.py` for example usage - add `/sdapi/v1/sampler` to get current sampler config - modify `/sdapi/v1/samplers` to enumerate available samplers possible options see `/cli/api-samplers.py` for example usage - **Internal** - tagged release history: each major for the past year is now tagged for easier reference - **torch** update *note*: may cause slow first startup/generate **cuda**: update to `torch==2.10.0` **xpu**: update to `torch==2.10.0` **rocm**: update to `torch==2.10.0` **openvino**: update to `torch==2.10.0` and `openvino==2025.4.1` - rocm: expand available gfx archs, thanks @crashingalexsan - rocm: set `MIOPEN_FIND_MODE=2` by default, thanks @crashingalexsan - relocate all json data files to `data/` folder existing data files are auto-migrated on startup - refactor and improve connection monitor, thanks @awsr - further work on type consistency and type checking, thanks @awsr - log captured exceptions - improve temp folder handling and cleanup - remove torch 
errors/warnings on fast server shutdown
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2026-01-20 - **Models** - [Flux.2 Klein](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) Flux.2-Klein is a new family of compact models from BFL in *4B and 9B sizes* and avaialable as *destilled and base* variants also includes are *sdnq prequantized variants* *note*: 9B variant is [gated](https://vladmandic.github.io/sdnext-docs/Gated/) - [Qwen-Image-2512](https://qwen.ai/blog?id=qwen-image-2512) Qwen-Image successor, significantly reduces the AI-generated look and adds finer natural detailils and improved text rendering available in both *original*, *sdnq-svd prequantized* and *sdnq-dynamic prequantized* variants thanks @CalamitousFelicitousness - [LTX-2 19B Dev](https://ltx.io/model/ltx-2) LTX-2 is a new very large 19B parameter video generation model from Lightricks using Gemma-3 text encoder available for T2I/I2I workflows in original and sdnq prequantized variants *note*: model is very sensitive to input params and will result in errors otherwise - [GLM-Image](https://z.ai/blog/glm-image) GLM-image is a new image generation model that adopts a hybrid autoregressive with diffusion decoder architecture available in both *original* and *sdnq-dynamic prequantized* variants thanks @CalamitousFelicitousness *note*: model requires pre-release versions of `transformers` package: > pip install --upgrade git+https://github.com/huggingface/transformers.git > ./webui.sh --experimental - [Nunchaku Z-Image Turbo](https://huggingface.co/nunchaku-tech/nunchaku-z-image-turbo) nunchaku optimized z-image turbo - **Feaures** - **SDNQ**: add 
*dynamic* quantization method sdnq can dynamically determine best quantization method for each module layer slower to quantize on-the-fly, but results in better quality with minimal resource usage - **SDNQ** now has *19 int* based and *69 float* based quantization types *note*: not all are exposed via ui purely for simplicity, but all are available via api and scripts - **wildcards**: allow weights, thanks @Tillerz - **sampler**: add laplace beta schedule results in better prompt adherence and smoother infills - **prompt enhance**: improve handling and refresh ui, thanks @CalamitousFelicitousness new models such moondream-3 and xiaomo-mimo add support for *thinking* mode where model can reason about the prompt add support for *vision* processing where prompt enhance can also optionally analyze input image add support for *pre-fill* mode where prompt enhance can continue from existing caption - **chroma**: add inpaint pipeline support - **taesd preview**: support for more models, thanks @alerikaisattera - **image ouput paths**: better handling of relative/absolute paths, thanks @CalamitousFelicitousness - **UI** - kanvas add send-to functionality - kanvas improve support for standardui - improve extensions tab layout and behavior, thanks @awsr - indicate collapsed/hidden sections - persistent panel minimize/maximize state - gallery improve sorting behavior - gallery implement prev/next navigation in full screen viewer, thanks @ryanmeador - **Internal** - **lora** native support by default will now skip text-encoder can be enabled in *settings -> networks* - update core js linting to `eslint9`, thanks @awsr - update modernui js linting to `eslint9`, thanks @awsr - update kanvas js linting to `eslint9`, thanks @awsr - update strong typing checks, thanks @awsr - update reference models previews, thanks @liutyi - update models specs page, thanks @alerikaisattera - sdnq improvements - startup sequence optimizations - rocm/hip/hipblast detection and initialization 
improvements - zluda detection and initialization improvements - new env variable `SD_VAE_DEFAULT` to force default vae processing - update `nunchaku==1.1.0` - lora switch logic from force-diffusers to allow-native - split `reference.json` - print system env on startup - disable fallback on models with custom loaders - refactor triggering of prompt parser and set secondary prompts when needed - refactor handling of seeds - allow unsafe ssl context for downloads - **Fixes** - controlnet: controlnet with non-english ui locales - core: add skip_keys to offloading logic, fixes wan frames mismatch, thanks @ryanmeador - core: force model move on offload=none - core: hidiffusion tracing - core: hip device name detection - core: reduce triton test verbosity - core: switch processing class not restoring params - extension tab: update checker, date handling, formatting etc., thanks @awsr - lora force unapply on change - lora handle null description, thanks @CalamitousFelicitousness - lora loading when using torch without distributed support - lora skip with strength zero - lora: generate slowdown when consequtive lora-diffusers enabled - model: google-genai auth, thanks @CalamitousFelicitousness - model: improve qwen i2i handling - model: kandinsky-5 image and video on non-cuda platforms - model: meituan-longca-image-edit missing image param - model: wan 2.2 i2v - model: z-image single-file loader - other: update civitai base models, thanks @trojaner - ui: gallery save/delete - ui: mobile auto-collapse when using side panel, thanks @awsr - ui: networks filter by model type - ui: networks icon/list view type switch, thanks @awsr - vae: force align width/height to vae scale factor - wildards with folder specification ## Update for 2025-12-26 ### Highlights for 2025-12-26 End of year release update, just two weeks after previous one, with several new models and features: - Several new models including highly anticipated **Qwen-Image-Edit 2511** as well as 
**Qwen-Image-Layered**, **LongCat Image** and **Ovis Image** - New features including support for **Z-Image** *ControlNets* and *fine-tunes* and **Detailer** segmentation support [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-12-26 - **Models** - [LongCat Image](https://github.com/meituan-longcat/LongCat-Image) in *Image* and *Image Edit* variants LongCat is a new 8B diffusion base model using Qwen-2.5 as text encoder - [Qwen-Image-Edit 2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability - [Qwen-Image-Layered](https://huggingface.co/Qwen/Qwen-Image-Layered) in *base* and *pre-quantized* variants Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers *note*: set number of desired output layers in *settings -> model options* - [Ovis Image 7B](https://huggingface.co/AIDC-AI/Ovis-Image-7B) Ovis Image is a new text-to-image base model based on Qwen3 text-encoder and optimized for text-rendering - **Features** - Google **Gemini** and **Veo** models support for both *Dev* and *Vertex* access methods see [docs](https://vladmandic.github.io/sdnext-docs/Google-GenAI/) for details - **Z-Image Turbo** support loading transformer file-tunes in safetensors format as with any transformers/unet finetunes, place them then `models/unet` and use **UNET Model** to load safetensors file as they are not complete models - **Z-Image Turbo** support for **ControlNet Union** includes 1.0, 2.0 
and 2.1 variants - **Detailer** support for segmentation models some detection models can produce exact segmentation mask and not just box to enable, set `use segmentation` option added segmentation models: *anzhc-eyes-seg*, *anzhc-face-1024-seg-8n*, *anzhc-head-seg-8n* - **Internal** - update nightlies to `rocm==7.1` - mark `python==3.9` as deprecated - extensions improved status indicators, thanks @awsr - additional type-safety checks, thanks @awsr - add model info to ui overlay - **Wiki/Docs/Illustrations** - update models page, thanks @alerikaisattera - update reference models samples, thanks @liutyi - **Fixes** - generate forever fix loop checks, thanks @awsr - tokenizer expclit use for flux2, thanks @CalamitousFelicitousness - torch.compile skip offloading steps - kanvas css with standardui - control input media with non-english locales - handle embeds when on meta device - improve offloading when model has manual modules - ui section colapsible state, thanks @awsr - ui filter by model type ## Update for 2025-12-11 ### Highlights for 2025-12-11 *What's new?* New native [kanvas](https://vladmandic.github.io/sdnext-docs/Kanvas/) module for image manipulation that fully replaces *img2img*, *inpaint* and *outpaint* controls, massive update to **Captioning/VQA** models and features New generation of **Flux.2** large image model, new **Z-Image** model that is creating a lot of buzz, new **Kandinsky 5 Lite** image model and new **Photoroom PRX** model And first cloud models with **Google Nano Banana** *2.5 Flash and 3.0 Pro* and **Google Veo** *3.1* video model Also new are **HunyuanVideo 1.5** and **Kandinsky 5 Pro** video models Plus a lot of internal improvements and fixes ![Screenshot](https://github.com/user-attachments/assets/54b25586-b611-4d70-a28f-ee3360944034) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | 
[Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-12-11 - **Models** - [Black Forest Labs FLUX.2 Dev](https://bfl.ai/blog/flux-2) and prequantized variation [SDNQ-SVD-Uint4](https://huggingface.co/Disty0/FLUX.2-dev-SDNQ-uint4-svd-r32) **FLUX.2-Dev** is a brand new model from BFL and uses large 32B DiT together with Mistral 24B as text encoder model is available for text, image and edit tasks and can optionally use control input as second input image this is a very large model at ~100GB, so use of prequantized model at ~32GB is strongly advised using prequant version and default offloading, model runs on GPUs with ~20GB *note*: model is [gated](https://vladmandic.github.io/sdnext-docs/Gated/) - [Z-Image Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) and prequantized variation [SDNQ-SVD-Uint4](https://huggingface.co/Disty0/Z-Image-Turbo-SDNQ-uint4-svd-r32) **Z-Image** is a powerful and highly efficient image generation model with 6B parameters and using Qwen-3 as text encoder unlike most of new models that are far larger, Z-Image architecture allows it to run with good performance even on mid-range hardware *note*: initial release is *Turbo* variant only with *Base* and *Edit* variants to follow - [Kandinsky 5.0 Lite]() is a new 6B model using Qwen-2.5 as text encoder it comes in text-to-image and image-edit variants - **Google Gemini Nano Banana** [2.5 Flash](https://blog.google/products/gemini/gemini-nano-banana-examples/) and [3.0 Pro](https://deepmind.google/models/gemini-image/pro/) first cloud-based model directly supported in SD.Next UI *note*: need to set `GOOGLE_API_KEY` environment variable with your key to use this model - [Photoroom PRX 1024 Beta](https://huggingface.co/Photoroom/prx-1024-t2i-beta) PRX (Photoroom Experimental) is 
a small 1.3B parameter t2i model trained entirely from scratch, it uses T5-Gemma text-encoder - **Video** - [HunyuanVideo 1.5](https://huggingface.co/tencent/HunyuanVideo-1.5) in T2V and I2V variants, both standard and distilled and both 720p and 480p resolutions **HunyuanVideo 1.5** improves upon previous 1.0 version with better quality and higher resolution outputs, it uses Qwen2.5-VL text-encoder distilled variants provide faster generation with slightly reduced quality - [Kandinsky 5.0 Pro Video](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers) in T2V and I2V variants larger 19B (and more powerful version) of previously released Lite 2B models - [Google Veo 3.1](https://gemini.google/us/overview/video-generation/) for T2V and I2V workflows *note*: need to set `GOOGLE_API_KEY` environment variable with your key to use this model - **Kanvas**: new module for native canvas-based image manipulation kanvas is a full replacement for *img2img, inpaint and outpaint* controls see [docs](https://vladmandic.github.io/sdnext-docs/Kanvas/) for details *experimental*: report any feedback in master [issue](https://github.com/vladmandic/sdnext/issues/4358) - **Captioning** and **VQA: Visual Question & Answer** massive update to both features and supported models, thanks @CalamitousFelicitousness models: - additional `mooondream-2` features - support for `moondream-3-preview` - support for `qwen3-vl` with thinking - additional `gemma-3-vl` finetunes - support for `XiaomiMiMo` ui: - ability to annotate actual image, not just generate captions/answers e.g. 
actually mark detected regions/points features: - ui indicator of model capabilities - support for *prefill* style of prompting/answering - support for *reasoning* mode for supported models with option to output answer-only or reasoning-process - additional debug logging - **Other Features** - **wildcards**: allow recursive inline wildcards using curly braces syntax - **sdnq**: simplify pre-quantization saved config - **attention**: additional torch attention settings - **lora**: separate fuse setting for native-vs-diffuser implementations - **auth**: strong-enforce auth check on all api endpoints - **amdgpu**: prefer rocm-on-windows over zluda - **amdgpu**: improve rocm-on-windows installer - **sdnq**: improve dequant logic - **gallery**: significant performance improvements, thanks @awsr - **API** - `/control` endpoint is now fully compatible with scripts - `/control` additional params to control *xyz grid* see `cli/api-xyz.py` for simple example - `/detailers` new endpoint to list available detailers, both built-in and any custom downloaded - `/face-restorers` expanded to list model folders - **Internal** - python: set 3.10 as minimum supported version - sdnq: multiple improvements to quantization and dequantization logic - torch: update to `torch==2.9.1` for *cuda, ipex, openvino, rocm* backends - attention: refactor attention handling - scripts: remove obsolete video scripts - lint: update global lint rules - chrono: switch to official pipeline - pipeline: add optional preprocess and postprocess hooks - auth: wrap all internal api calls with auth check and use token when possible - installer: reduce requirements - installer: auto-restart on self-update - server: set correct mime-types - sdnq: unconditional register on startup - python: start work on future-proofing for modern python versions, thanks @awsr - nunchaku: update to `1.0.2` - lint: add rules for run-on-windows - gallery: setting to enable/disable client-side caching, thanks @awsr - gallery: faster
thumbnail generation, thanks @awsr - gallery: purge old thumbnails, thanks @awsr - **Docs** - update supported models table with VAE information, thanks @alerikaisattera - **Fixes** - xyz-grid: improve parsing of axis lists, thanks @awsr - hires: strength save/load in metadata, thanks @awsr - img2img: fix initial scale tab, thanks @awsr - img2img: fix restoring refine sampler from metadata, thanks @awsr - log: client log formatting, thanks @awsr - rocm: check if installed before forcing install - pony-v7: fix text-encoder - detailer: with face-restorers - detailer: using lora in detailer prompt - detailer: fail on unsupported models instead of corrupting results - ui: fix collapsible panels - svd: fix stable-video-diffusion dtype mismatch - animatediff: disable sdnq if used - lora: restore pipeline type if reload/recompile needed - process: improve send-to functionality - control: safe load non-sparse controlnet - control: fix marigold preprocessor with bfloat16 - auth: fix password being shown in clear text during login - firefox: remove obsolete checks, thanks @awsr - runai streamer: cleanup logging, thanks @CalamitousFelicitousness - gradio: event handlers, thanks @awsr ## Update for 2025-11-06 ### Highlights for 2025-11-06 Service pack release that handles critical issues and improvements for **ROCm-on-Windows** and **ZLUDA** backends Also included are several new features, notably improvements to **detailer** and ability to run [SD.Next](https://github.com/vladmandic/sdnext) with specific modules disabled And new video model, **nVidia SANA 2B** ![Screenshot](https://github.com/user-attachments/assets/d6119a63-6ee5-4597-95f6-29ed0701d3b5) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) |
[Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-11-06 - **Models** - [SANA Video_2B_480p T2V](https://huggingface.co/Efficient-Large-Model/SANA-Video_2B_480p_diffusers) is a small 2B ultra-efficient diffusion model designed for rapid generation of high-quality videos and uses Gemma2 text encoder - **Features** - **ROCm for Windows** switch to using **TheRock** `torch` builds when available recommended to run: `webui --use-rocm --reinstall` - **ZLUDA** improve detection and handling of unsupported GPUs recommended to run: `webui --use-zluda --reinstall` - **detailer** optional include detection image to output results optional sort detection objects left-to-right for improved prompt consistency enable multi-subject and multi-model prompts - **disable modules** ability to disable parts of the app useful for custom deployments where some features are not desired *note*: this doesn't just hide it from user, it completely disables the code paths use `--disable x,y,z` possible values: - main tabs: *control,txt2img,img2img,video,extras,caption,gallery* - aside tabs: *extensions,models,info,update,history,monitor,onnx,system,networks,logs* - special: *settings,config* (hidden instead of disabled) - **wildcards**: add inline processing using curly braces syntax - add setting to control `cudnn` enable/disable *note*: this can also be used to enable/disable `MIOpen` on ROCm backends - change `vlm` beams to 1 by default for faster response - **controlnet** allow processor to keep aspect-ratio for override images based on i2i or t2i resolution - **networks** info details now displays image metadata from preview image - **networks** new model previews, thanks @liutyi - **Fixes** - zluda: test and disable MIOpen as needed - qwen: improve lora compatibility - chrono: transformers handling - chrono: extract last frame - chrono: add vae scale override, thanks 
@CalamitousFelicitousness - runai: improve streamer integration - transformers: `dtype` use new syntax - rocm: possible endless loop during hip detection - rocm: auto-disable `miopen` for gfx120x - detailer: better handling of settings, thanks @awsr - installer: cleanup `--optional` - hires: guard against multi-controlnet - inpaint: fix init - version: detection when cloned with .git suffix, thanks @awsr - sdnq: init on video model load - model type: detection - model type: add tracing to model detection - settings: guard against non-string values, thanks @awsr - ui: wait for server options to be ready before initializing ui - ui: fix full-screen image viewer buttons with non-standard ui theme - ui: control tab show override section - ui: mobile layout for video tab - ui: increase init timeout - video: save to subfolder - taesd: warn on long decode times - metadata: keep exif on thumbnail generation - wildcard: obey seed for reproducible results - sageattention: handle possible triton issues on some nvidia gpus, thanks @CalamitousFelicitousness ## Update for 2025-10-31 ### Highlights for 2025-10-31 Less than 2 weeks since last release, here's a service-pack style update with a lot of fixes and improvements: - Reorganization of **Reference Models** into *Base, Quantized, Distilled and Community* sections for easier navigation and introduction of optimized **pre-quantized** variants for many popular models - use this as your quick start! 
- New models: **HunyuanImage 2.1** capable of 2K images natively, **HunyuanImage 3.0** large unified multimodal autoregressive model, **ChronoEdit** that re-purposes temporal consistency of generation for image editing **Pony 7** based on AuraFlow architecture, **Kandinsky 5** 10s video models - New **offline mode** to use previously downloaded models without internet connection - Optimizations to **WAN-2.2** given its popularity plus addition of native **VAE Upscaler** and optimized **pre-quantized** variants - New SOTA model loader using **Run:ai streamer** - Updates to `rocm` and `xpu` backends - Fixes, fixes, fixes... too many to list here! ![Screenshot](https://github.com/user-attachments/assets/d6119a63-6ee5-4597-95f6-29ed0701d3b5) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-10-31 - **Reference** networks section is now split into actual *Base* models plus: - **Quantized**: pre-quantized variants of the base models using SDNQ-SVD quantization for optimal quality and smallest possible resource usage examples: *FLUX.1-Dev/Krea/Kontext/Schnell, Qwen-Image/Edit/2509, Chroma1-HD, WAN-2.2-A44B, etc.* *note*: pre-quantized *WAN-2.2-14B* is also available in video models and runs with only 12GB VRAM! 
- **Distilled**: distilled variants of base models examples: *Turbo, Lightning, Lite, SRPO, Distill, Pruning, etc.* - **Community**: community highlights examples: *Tempest, Juggernaut, Illustrious, Pony, NoobAI, etc.* and all reference models have new preview images, thanks @liutyi - **Models Reference** - [Tencent HunyuanImage 2.1](https://huggingface.co/tencent/HunyuanImage-2.1) in *full*, *distilled* and *refiner* variants *HunyuanImage-2.1* is a large (51GB) T2I model capable of natively generating 2K images and uses Qwen2.5 + T5 text-encoders and 32x VAE - [Tencent HunyuanImage 3.0](https://huggingface.co/tencent/HunyuanImage-3.0) in [pre-quant](https://huggingface.co/Disty0/HunyuanImage3-SDNQ-uint4-svd-r32) only variant due to massive size *HunyuanImage 3.0* is very large at 47GB pre-quantized (otherwise it's 157GB) that unifies multimodal understanding and generation within an autoregressive framework - [nVidia ChronoEdit](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers) *ChronoEdit* is a 14B image editing model based on *WAN* this model reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency to extend temporal consistency for image editing, set *settings -> model options -> chrono temporal steps* to desired number of temporal reasoning steps - [Kandinsky 5 Lite 10s](https://huggingface.co/ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers) in *SFT, CFG-distilled and Steps-distilled* variants second series of models in *Kandinsky5* series is T2V model optimized for 10sec videos and uses Qwen2.5 text encoder - [Pony 7](https://huggingface.co/purplesmartai/pony-v7-base) Pony 7 steps in a different direction from previous Pony models and is based on AuraFlow architecture and UMT5 encoder - **Models Auxiliary** - [Qwen 3-VL](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) VLM for interrogate and prompt enhance, thanks @CalamitousFelicitousness
this includes *2B, 4B and 8B* variants - [WAN Asymmetric Upscale](https://huggingface.co/spacepxl/Wan2.1-VAE-upscale2x) available as general purpose upscaler that can be used during standard workflow or process tab available as VAE for compatible video models: *WAN-2.x-14B, SkyReels-v2* models - [Apple DepthPro](https://huggingface.co/apple/DepthPro) controlnet processor, thanks @nolbert82 - [LibreFlux controlnet](https://huggingface.co/neuralvfx/LibreFlux-ControlNet) segmentation controlnet for FLUX.1 - **Features** - **offline mode**: enable in *settings -> huggingface* enables fully offline mode where previously downloaded models can be used as-is *note*: must be enabled only after all packages have been installed and model has been run online at least once - **model load**: SOTA method using nVidia's [Run:ai streamer](https://github.com/run-ai/runai-model-streamer) enable in *settings -> model options -> runai streamer* applies to *diffusers, transformers and sdnq* loaders, note this is linux-only feature *experimental* but shows significant model load speedups, 20-40% depending on model and hardware - **Backend** - switch to `torch==2.9` for *ipex, rocm and openvino* - switch to `rocm==7.0` for nightlies - log `triton` availability on startup - add `xpu` stats in gpu monitor - **Other** - improved **SDNQ SVD** and low-bit matmul performance - reduce RAM usage on model load using **SDNQ SVD** - change default **schedulers** for sdxl - warn on `python==3.9` end-of-life and `python==3.10` not actively supported - **scheduler** add base and max shift parameters for flow-matching samplers - enhance `--optional` flag to pre-install optional packages - add `[lora]` to recognized filename patterns - when using **shared-t5** *(default)*, it will load standard or pre-quant depending on model - enhanced LoRA support for **Wan-2.2-14B** - log available attention mechanisms on startup - support for switching back-and-forth **t2i** and **t2v** for *wan-2.x* models - control
`api` cache controlnets - additional model modules **deduplication** for both normal and pre-quant models: *umt5, qwen25-vl* - **Fixes** - startup error with `--profile` enabled if using `--skip` - restore orig init image for each batch sequence - fix modernui hints layout - fix `wan-2.2-a14b` stage selection - fix `wan-2.2-5b` vae decode - disabling live preview should not disable progress updates - video tab create `params.txt` with metadata - fix full-screen image-viewer toolbar actions with control tab - improve filename sanitization - lora auto-detect low/high stage if not specified - lora disable fuse on partially applied network - fix networks display with extended characters, thanks @awsr - installer handle different `opencv` package variants - fix using pre-quantized shared-t5 - fix `wan-2.2-14b-vace` single-stage execution - fix `wan-2.2-5b` tiled vae decode - fix `controlnet` loading with quantization - video use pre-quantized text-encoder if selected model is pre-quantized - handle sparse `controlnet` models - catch `xet` warnings - avoid unnecessary pipe variant switching - validate pipelines on import - fix `nudenet` process tab operations - `controlnet` input validation - log metadata keys that cannot be applied - fix `framepack` with image input ## Update for 2025-10-18 - **Models** [Kandinsky 5 Lite](https://huggingface.co/ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers) in *SFT, CFG-distilled and Steps-distilled* variants first model in Kandinsky5 series is T2V model optimized for 5sec videos and uses Qwen2.5 text encoder - **Fixes** - ROCm-on-Windows additional checks - SDNQ-SVD fallback on incompatible layers - Huggingface model download - Video implement dynamic and manual sampler shift - Fix interrupt batch processing - Delay import of control processors until used - Fix tiny VAE with batched results - Fix CFG scale not added to metadata and set valid range to >=1.0 - **Other** - Optimized Video tab layout - Video enable VAE slicing and
framewise decoding when possible - Detect and log `flash-attn` and `sageattention` if installed - Remove unused UI settings ## Update for 2025-10-17 ### Highlights for 2025-10-17 It's been a month since the last release and number of changes is yet again massive with over 300 commits! Highlight are: - **Torch**: ROCm on Windows for AMD GPUs if you have a compatible GPU, performance gains are significant! - **Models**: a lot of new stuff with **Qwen-Image-Edit** including multi-image edits and distilled variants, new **Flux**, **WAN**, **LTX**, **HiDream** variants, expanded **Nunchaku** support and new SOTA upscaler with **SeedVR2** plus improved video support in general, including new methods of video encoding - **Quantization**: new **SVD**-style quantization using SDNQ offers almost zero-loss even with **4bit** quantization and now you can also test your favorite quantization on-the-fly and then save/load model for future use - Other: support for **Huggingface** mirrors, changes to installer to prevent unwanted `torch-cpu` operations, improved VAE previews, etc. 
![Screenshot](https://github.com/user-attachments/assets/d6119a63-6ee5-4597-95f6-29ed0701d3b5) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-10-17 - **Models** - [WAN 2.2 14B VACE](https://huggingface.co/alibaba-pai/Wan2.2-VACE-Fun-A14B) available for *text-to-image* and *text-to-video* and *image-to-video* workflows - [Qwen Image Edit 2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) and [Nunchaku Qwen Image Edit 2509](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit-2509) updated version of Qwen Image Edit with improved image consistency - [Qwen Image Pruning](https://huggingface.co/OPPOer/Qwen-Image-Pruning) and [Qwen Image Edit Pruning](https://huggingface.co/OPPOer/Qwen-Image-Edit-Pruning) pruned versions of Qwen with 13B params instead of 20B, with some quality tradeoff - [Tencent FLUX.1 Dev SRPO](https://huggingface.co/tencent/SRPO) SRPO is trained by Tencent with specific technique: directly aligning the full diffusion trajectory with fine-grained human preference - [Nunchaku SDXL](https://huggingface.co/nunchaku-tech/nunchaku-sdxl) and [Nunchaku SDXL Turbo](https://huggingface.co/nunchaku-tech/nunchaku-sdxl-turbo) impact of nunchaku engine on unet-based model such as sdxl is much less than on a dit-based models, but its still significantly faster than baseline note that nunchaku optimized and pre-quantized unet is replacement for base unet, so its only applicable to base models, not any of fine-tunes *how to use*: enable nunchaku in settings -> quantization and then load either sdxl-base or sdxl-base-turbo reference models - [HiDream 
E1.1](https://huggingface.co/HiDream-ai/HiDream-E1-1) updated version of HiDream-E1 image editing model - [LTXVideo 0.9.8](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled) updated version of LTXVideo t2v/i2iv model - [SeedVR2](https://iceclear.github.io/projects/seedvr/) originally designed for video restoration, seedvr works great for image detailing and upscaling! available in 3B, 7B and 7B-sharp variants, use as any other upscaler! note: seedvr is a very large model (6.4GB and 16GB respectively) and not designed for lower-end hardware, quantization is highly recommended note: seedvr is highly sensitive to its cfg scale, set in *settings -> postprocessing* lower values will result in smoother output while higher values add details - [X-Omni SFT](https://x-omni-team.github.io/) *experimental*: X-omni is a transformer-only discrete auto-regressive image generative model trained with reinforcement learning - **Features** - **Model save**: ability to save currently loaded model as a new standalone model why? 
SD.Next always prefers to start with full model and quantize on-demand during load however, when you find your exact preferred quantization settings that work well for you, saving such model as a new model allows for faster loads and reduced disk space usage so its best of both worlds: you can experiment and test different quantization methods and once you find the one that works for you, save it as a new model saved models appear in network tab as normal models and can be loaded as such available in *models* tab - [Qwen Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) multi-image editing requires qwen-image-edit-2509 or its variant as multi-image edits are not available in original qwen-image in ui control tab: inputs -> separate init image add image for *input media* and *control media* can be - [Cache-DiT](https://github.com/vipshop/cache-dit) cache-dit is a unified, flexible and training-free cache acceleration framework compatible with many dit-based models such as FLUX.1, Qwen, HunyuanImage, Wan2.2, Chroma, etc. 
enable in *settings -> pipeline modifiers -> cache-dit* - [Nunchaku Flux.1 PulID](https://nunchaku.tech/docs/nunchaku/python_api/nunchaku.pipeline.pipeline_flux_pulid.html) automatically enabled if loaded model is FLUX.1 with Nunchaku engine enabled and when PulID script is enabled - **Huggingface mirror** in *settings -> huggingface* if you're working from location with limited access to huggingface, you can now specify a mirror site for example enter, `https://hf-mirror.com` - **Compute** - **ROCm** for Windows support for both official torch preview release of `torch-rocm` for windows and **TheRock** unofficial `torch-rocm` builds for windows note that rocm for windows is still in preview and has limited gpu support, please check rocm docs for details - **DirectML** warn as *end-of-life* `torch-directml` received no updates in over 1 year and its currently superseded by `rocm` or `zluda` - command line params `--use-zluda` and `--use-rocm` will attempt desired operation or fail if not possible previously sdnext was performing a fallback to `torch-cpu` which is not desired - **installer** if `--use-cuda` or `--use-rocm` are specified and `torch-cpu` is installed, installer will attempt to reinstall correct torch package - **installer** warn if *cuda* or *rocm* are available and `torch-cpu` is installed - support for `torch==2.10-nightly` with `cuda==13.0` - **Extensions** - [Agent-Scheduler](https://github.com/SipherAGI/sd-webui-agent-scheduler) was a high-value built-in extension, but it has not been maintained for 1.5 years it also does not work with control and video tabs which are the core of sdnext nowadays so it has been removed from built-in extensions: manual installation is still possible - [DAAM: Diffusion Attentive Attribution Maps](https://github.com/castorini/daam) create heatmap visualizations of which parts of the prompt influenced which parts of the image available in scripts for sdxl text-to-image workflows - **Offloading** - improve offloading 
for pipelines with multiple stages such as *wan-2.2-14b* - add timers to measure onload/offload times during generate - experimental offloading using `torch.streams` enable in settings -> model offloading - new feature to specify which models types not to offload in *settings -> model offloading -> model types not to offload* - **UI** - **connection monitor** main logo in top-left corner now indicates server connection status and hovering over it shows connection details - separate guidance and detail sections - networks ability to filter lora by base model version - add interrogate button to input images - disable spellchecks on all text inputs - **SDNQ** - add `SVDQuant` quantization method support - make sdnq scales compatible with balanced offload - add int8 `matmul` support for RDNA2 GPUs via triton - improve int8 `matmul` performance on Intel GPUs - **Other** - server will note when restart is recommended due to package updates - **interrupt** will now show last known preview image *keep incomplete* setting is now *save interrupted* - **logging** enable `debug`, `docs` and `api-docs` by default - **logging** add detailed ram/vram utilization info to log logging frequency can be specified using `--monitor x` command line param, where x is number of seconds - **ipex** simplify internal implementation - refactor to use new libraries - styles and wildcards now use same seed as main generate for reproducible results - **api** new endpoint POST `/sdapi/v1/civitai` to trigger civitai models metadata update accepts optional `page` parameter to search specific networks page - **reference models** additional example images, thanks @liutyi - **reference models** add model size and release date, thanks @alerikaisattera - **video** support for configurable multi-stage models such as WAN-2.2-14B - **video** new LTX model selection - replace `pynvml` with `nvidia-ml-py` for gpu monitoring - update **loopback** script with random seed option, thanks @rabanti - **vae** slicing
enable for *lowvram/medvram*, tiling for *lowvram*, both disabled otherwise - **attention** remove split-attention and add explicitly attention slicing enable/disable option enable in *settings -> compute settings* can be combined with sdp, enabling may improve stability when used on iGPU or shared memory systems - **nunchaku** update to `1.0.1` and enhance installer - **xyz-grid** add guidance section - **preview** implement configurable layers for WAN, Qwen, HV - update swagger `/docs` endpoint style - add `[epoch]` to filename template - starting `[seq]` for filename template is now higher of largest previous sequence or number of files in folder - **Video** - use shared **T5** text encoder for video models when possible - use shared **LLama** text encoder for video models when possible - unified video save code across all video models also avoids creation of temporary files for each frame unless user wants to save them - unified prompt enhance code across all video models - add job state tracking for video generation - fix quantization not being applied on load for some models - improve offloading for **ltx** and **wan** - fix model selection in **ltx** tab - **Experimental** - `new` command line flag enables new `pydantic` and `albumentations` packages - **modular pipelines**: enable in *settings -> model options* only compatible with some pipelines, invalidates preview generation - **modular guiders**: automatically used for compatible pipelines when *modular pipelines* is enabled allows for using many different guidance methods: *CFG, CFGZero, PAG, APG, SLG, SEG, TCFG, FDG* - **Wiki** - updates to *AMD-ROCm, ZLUDA, LoRA, DirectML, SDNQ, Quantization, Prompting, LoRA* pages - new *Stability-Matrix* page - **Fixes** - **Microsoft Florence 2** both base and large variants *note* this will trigger download of the new variant of the model, feel free to delete older variant in `huggingface` folder - **MiaoshouAI PromptGen** 1.5/2.0 in both base and large variants 
- fix prompt scheduling, thanks @nolbert82 - ui: fix image metadata display when switching selected image in control tab - framepack: add explicit hf-login before framepack load - framepack: patch solver for unsupported gpus - benchmark: remove forced sampler from system info benchmark - xyz-grid: fix xyz grid with random seeds - reference: fix download for sd15/sdxl reference models - fix checks in init/mask image decode - fix hf token with extra chars - image viewer refocus on gallery after returning from full screen mode - fix attention guidance metadata save/restore - vae preview add explicit cuda.sync ## Update for 2025-09-15 ### Highlights for 2025-09-15 *What's new*? Big one is that we're (*finally*) switching the default UI to **ModernUI**, for both desktop and mobile use! **StandardUI** is still available and can be selected in settings, but ModernUI is now the default for new installs *What else*? **Chroma** is in its final form, there are several new **Qwen-Image** variants and **Nunchaku** hit version 1.0! Also, there are quite a few offloading improvements and many quality-of-life changes to UI and overall workflows And check out new **history** tab in the right panel, it now shows visualization of entire processing timeline!
![Screenshot](https://github.com/user-attachments/assets/d6119a63-6ee5-4597-95f6-29ed0701d3b5) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic) ### Details for 2025-09-15 - **Models** - **Chroma** final versions: [Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD), [Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) and [Chroma1-Flash](https://huggingface.co/lodestones/Chroma1-Flash) - **Qwen-Image** [InstantX ControlNet Union](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union) support *note* qwen-image is already a very large model and controlnet adds 3.5GB on top of that so quantization and offloading are highly recommended! 
- [Qwen-Lightning-Edit](https://huggingface.co/vladmandic/Qwen-Lightning-Edit) and [Qwen-Image-Distill](https://huggingface.co/SahilCarterr/Qwen-Image-Distill-Full) variants - **Nunchaku** variants of [Qwen-Image-Lightning](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image), [Qwen-Image-Edit](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit), [Nunchaku-Qwen-Image-Edit-Lightning](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image-edit) - **Nunchaku** variant of [Flux.1-Krea-Dev](https://huggingface.co/nunchaku-tech/nunchaku-flux.1-krea-dev) if you have a compatible nVidia GPU, Nunchaku is the fastest quantization & inference engine - [HunyuanDiT ControlNet](https://huggingface.co/Tencent-Hunyuan/HYDiT-ControlNet-v1.2) Canny, Depth, Pose - [KBlueLeaf/HDM-xut-340M-anime](https://huggingface.co/KBlueLeaf/HDM-xut-340M-anime) highly experimental: HDM *Home-made-Diffusion-Model* is a project to investigate specialized training recipe/scheme for pre-training T2I model at home based on super-light architecture *requires*: generator=cpu, dtype=float16, offload=none, both positive and negative prompts are required and must be long & detailed - [Apple FastVLM](https://huggingface.co/apple/FastVLM-0.5B) in 0.5B, 1.5B and 7B variants available in captioning tab - updated [SD.Next Model Samples Gallery](https://vladmandic.github.io/sd-samples/compare.html) - **UI** - default to **ModernUI** standard ui is still available via *settings -> user interface -> theme type* - mobile-friendly! 
- new **History** section in the right panel shows detailed job history plus timeline of the execution - make hints touch-friendly: hold touch to display hint - improved image scaling in img2img and control interfaces - add base model type to networks display, thanks @Artheriax - additional hints to ui, thanks @Artheriax - add video support to gallery, thanks @CalamitousFelicitousness - additional artwork for reference models in networks, thanks @liutyi - improve ui hints display - restyled all toolbuttons to be modernui native - reordered system settings - dynamic direction of dropdowns - improve process tab layout - improve detection of active tab - configurable horizontal vs vertical panel layout in settings -> user interface -> panel min width *example*: if panel width is less than specified value, layout switches to vertical - configurable grid images size in *settings -> user interface -> grid image size* - gallery now includes reference model images - reference models now include indicator if they are *ready* or *need download* - **Offloading** - **balanced** - enable offload during pre-forward by default - improve offloading of models with multiple dits - improve offloading of models with impliciy vae processing - improve offloading of models with controlnet - more aggressive offloading of controlnet with lowvram flag - **group** - new offloading method, using *type=leaf* works on a similar level as sequential offloading and can present significant savings on low-vram gpus, but comes at the higher performance cost - **Quantization** - option to specify models types not to quantize: *settings -> quantization* allows for having quantization enabled, but skipping specific model types that do not need it *example*: `sd, sdxl` - **sdnq** - add quantized matmul support for all quantization types and group sizes - improve the performance of low bit quants - **nunchaku**: update to `nunchaku==1.0.0` *note*: nunchaku updated the repo which will trigger re-download 
of nunchaku models when first used nunchaku is currently available for: *Flux.1 Dev/Schnell/Kontext/Krea/Depth/Fill*, *Qwen-Image/Qwen-Lightning*, *SANA-1.6B* - **tensorrt**: new quantization engine from nvidia *experimental*: requires new pydantic package which *may* break other things, to enable start sdnext with `--new` flag *note*: this is model quantization only, no support for tensorRT inference yet - **Other** - **LoRA** allow specifying module to apply lora on *example*: `` would apply lora *only* on unet regardless of lora content this is particularly useful when you have multiple loras and you want to apply them on different parts of the model *example*: `` and `` *note*: `low` is shorthand for `module=transformer_2` and `high` is shortcut for `module=transformer` - **Detailer** allow manually setting processing resolution *note*: this does not impact the actual image resolution, only the resolution at which detailer internally operates - refactor reuse-seed and add functionality to all tabs - refactor modernui js codebase - move zluda flash attention to *Triton Flash attention* option - remove samplers filtering - allow both flow-matching and discrete samplers for sdxl models - cleanup command line parameters - add `--new` command line flag to enable testing of new packages without breaking existing installs - downgrade rocm to `torch==2.7.1` - set the minimum supported rocm version on linux to `rocm==6.0` - disallow `zluda` and `directml` on non-windows platforms - update openvino to `openvino==2025.3.0` - add deprecation warning for `python==3.9` - allow setting denoise strength to 0 in control/img2img this allows to run workflows which only refine or detail existing image without changing it - **Fixes** - normalize path handling when deleting images - unified compile upscalers - fix OpenVINO with ControlNet - fix hidden model tags in networks display - fix networks reference models display on windows - fix handling of pre-quantized `flux` models - fix
`wan` use correct pipeline for i2v models - fix `qwen-image` with hires - fix `omnigen-2` failure - fix `auraflow` quantization - fix `kandinsky-3` noise - fix `infiniteyou` pipeline offloading - fix `skyreels-v2` image-to-video - fix `flex2` img2img denoising strength - fix `flex2` controlnet vs inpaint image selection, thanks @alerikaisattera - fix some use cases with access via reverse-proxy - fix segfault on startup with `rocm==6.4.3` and `torch==2.8` - fix wildcards folders traversal, thanks @dymil - fix zluda flash attention with enable_gqa - fix `wan a14b` quantization - fix reprocess workflow for control with hires - fix samplers set timesteps vs sigmas - fix `detailer` missing metadata - fix `infiniteyou` lora load with ## Update for 2025-08-20 A quick service release with several important hotfixes, improved localization support and adding new **Qwen** model variants... [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) - **Models** - [Qwen-Image-Edit](https://huggingface.co/Qwen/Qwen-Image-Edit) Image editing using natural language prompting, similar to `Flux.1-Kontext`, but based on larger 20B `Qwen-Image` model - [Nunchaku-Qwen-Image](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image) if you have a compatible nVidia GPU, Nunchaku is the fastest quantization engine, currently available for Flux.1, SANA and Qwen-Image models *note*: release version of `nunchaku==0.3.2` does NOT include support, so you need to build [nunchaku](https://nunchaku.tech/docs/nunchaku/installation/installation.html) from source - [SD.Next Model Samples Gallery](https://vladmandic.github.io/sd-samples/compare.html) - updated with new models - **Features** - new *setting -> 
huggingface -> download method* default is `rust` as new `xet` is known to cause issues - support for `flux.1-kontext` lora - support for `qwen-image` lora - new *setting -> quantization -> modules dtype dict* used to manually override quant types per module - **UI** - new artwork for reference models in networks thanks @liutyi - updated [localization](https://vladmandic.github.io/sdnext-docs/Locale/) for all 8 languages - localization support for ModernUI - single-click on locale rotates current locale double-click on locale resets locale to `en` - exclude ModernUI from list of extensions ModernUI is enabled in settings, not by manually enabling extension - **Docs** - Models and Video pages updated with links to original model repos, model licenses and original release dates thanks @alerikaisattera - **Fixes** - nunchaku use new download links and default to `0.3.2` nunchaku wheels: - fix OpenVINO with offloading - add explicit offload calls on prompt encode - error reporting on model load failure - fix torch version checks - remove extra cache clear - enable explicit sync calls for `rocm` on windows - note if restart-needed on initial startup import error - bypass diffusers-lora-fuse on quantized models - monkey-patch diffusers to use original weights shape when loading lora - guard against null prompt - install `hf_transfer` and `hf_xet` when needed - fix ui cropped network tags - enum reference models on startup - don't report errors if agent scheduler is disabled ## Update for 2025-08-15 ### Highlights for 2025-08-15 New release two weeks after the last one and it's a big one with over 150 commits! 
- Several new models: [Qwen-Image](https://qwenlm.github.io/blog/qwen-image/) (plus *Lightning* variant) and [FLUX.1-Krea-Dev](https://www.krea.ai/blog/flux-krea-open-source-release) - Several updated models: [Chroma](https://huggingface.co/lodestones/Chroma), [SkyReels-V2](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers), [Wan-VACE](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers), [HunyuanDiT](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers-Distilled) - Plus continuing with major **UI** work with new embedded **Docs/Wiki** search, redesigned real-time **hints**, **wildcards** UI selector, built-in **GPU monitor**, **CivitAI** integration and more! - On the compute side, new profiles for high-vram GPUs, offloading improvements, parallel-load for large models, support for new `torch` release and improved quality when using low-bit quantization! - [SD.Next Model Samples Gallery](https://vladmandic.github.io/sd-samples/compare.html): pre-generated image gallery with 60 models (45 base and 15 finetunes) and 40 different styles resulting in 2,400 high resolution images! gallery additionally includes model details such as typical load and inference times as well as sizes and types of each model component (*e.g. unet, transformer, text-encoder, vae*) - And (*as always*) many bugfixes and improvements to existing features! 
![sd-samples](https://github.com/user-attachments/assets/3efc8603-0766-4e4e-a4cb-d8c9b13d1e1d) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) *Note*: Change-in-behavior - locations of downloaded HuggingFace models and components are changed to allow for de-duplication of common modules and switched from using system default cache folder to `models/huggingface` SD.Next will warn on startup on unused cache entries that can be removed. Also, to take advantage of de-duplication, you'll need to delete models from your `models/Diffusers` folder and let SD.Next re-download them! ### Details for 2025-08-15 - **Models** - [Qwen-Image](https://qwenlm.github.io/blog/qwen-image/) new image foundational model with *20B* params DiT and using *Qwen2.5-VL-7B* as the text-encoder! available via *networks -> models -> reference* *note*: this model is almost 2x the size of Flux, quantization and offloading are highly recommended! *recommended* params: *steps=50, attention-guidance=4* also available is pre-packaged [Qwen-Lightning](https://huggingface.co/vladmandic/Qwen-Lightning) which is an unofficial merge of [Qwen-Image](https://qwenlm.github.io/blog/qwen-image/) with [Qwen-Lightning-LoRA](https://github.com/ModelTC/Qwen-Image-Lightning/) to improve quality and allow for generating in 8-steps! 
- [FLUX.1-Krea-Dev](https://www.krea.ai/blog/flux-krea-open-source-release) new 12B base model compatible with FLUX.1-Dev from *Black Forest Labs* with opinionated aesthetics and aesthetic preferences in mind available via *networks -> models -> reference* - [Chroma](https://huggingface.co/lodestones/Chroma) great model based on FLUX.1 and then redesigned and retrained by *lodestones* update with latest **HD**, **HD Flash** and **HD Annealed** variants which are based on *v50* release available via *networks -> models -> reference* - [SkyReels-V2](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers) SkyReels-V2 is a generative video model based on Wan-2.1 but with heavily modified execution to allow for infinite-length video generation supported variants are: - diffusion-forcing: *T2I DF 1.3B* for 540p videos, *T2I DF 14B* for 720p videos, *I2I DF 14B* for 720p videos - standard: *T2I 14B* for 720p videos and *I2I 14B* for 720p videos - [Wan-VACE](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers) basic support for *Wan 2.1 VACE 1.3B* and *14B* variants optimized support with granular guidance control will follow soon - [HunyuanDiT-Distilled](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers-Distilled) variant of HunyuanDiT with reduced steps and improved performance **Torch** - Set default to `torch==2.8.0` for *CUDA, ROCm and OpenVINO* - Add support for `torch==2.9.0-nightly` - **UI** - new embedded docs/wiki search! **Docs** search: fully-local and works in real-time on all document pages **Wiki** search: uses github api to search online wiki pages - updated real-time hints, thanks @CalamitousFelicitousness - add **Wildcards** UI in networks display - every heading element is collapsible! - quicksettings reset button to restore all quicksettings to default values because things do sometimes get wrong... 
- configurable image fit in all image views - rewritten **CivitAI downloader** in *models -> civitai* *hint*: you can enter model id in a search bar to pull information on specific model directly *hint*: you can download individual versions or batch-download all-at-once! - redesigned **GPU monitor** - standard-ui: *system -> gpu monitor* - modern-ui: *aside -> console -> gpu monitor* - supported for *nVidia CUDA* and *AMD ROCm* platforms - configurable interval in *settings -> user interface* - updated *models* tab - updated *models -> current* tab - updated *models -> list models* tab - updated *models -> metadata* tab - updated *extensions* tab - redesigned *settings -> user interface* - gallery bypass browser cache for thumbnails - gallery safer delete operation - networks display indicator for currently active items applies to: *styles, loras* - apply privacy blur to hf and civitai tokens - image download will now use actual image filename - increase default and maximum ui request timeout to 2min/5min - *hint*: card layout card layout is used by networks, gallery, civitai search, etc. 
you can change card size in *settings -> user interface* - **Offloading** - changed **default** values for offloading based on detected gpu memory see [offloading docs](https://vladmandic.github.io/sdnext-docs/Offload/) for details - new feature to specify which modules to offload always or never in *settings -> model offloading -> offload always/never* - new `highvram` profile provides significant performance boost on gpus with more than 24gb - new `offload during pre-forward` option in *settings -> model offloading* switches from explicit offloading to implicit offloading on module execution change - new `diffusers_offload_nonblocking` experimental setting instructs torch to use non-blocking move operations when possible - **Features** - new `T5: Use shared instance of text encoder` option in *settings -> text encoder* since a lot of new models use T5 text encoder, this option allows to share the same instance across all models without duplicate downloads *note* this will not reduce size of your already downloaded models, but will reduce size of future downloads - **Wan** select which stage to run: *first/second/both* with configurable *boundary ratio* when running both stages in settings -> model options - prompt parser allow explicit `BOS` and `EOS` tokens in prompt - **Nunchaku** support for *FLUX.1-Fill* and *FLUX.1-Depth* models - update requirements/packages - use model vae scale-factor for image width/height calculations - **SDNQ** add `modules_dtype_dict` to quantize *Qwen Image* with mixed dtype - **prompt enhance** add `allura-org/Gemma-3-Glitter-4B`, `Qwen/Qwen3-4B-Instruct-2507`, `Qwen/Qwen2.5-VL-3B-Instruct` model support improve system prompt - **schedulers** add **Flash FlowMatch** - **model loader** add parallel loader option enabled by default, selectable in *settings -> model loading* - **filename namegen** use exact sequence number instead of next available this allows for more predictable and consistent filename generation - **network delete** 
new feature that allows to delete network from disk in *networks -> show details -> delete* this will also delete description, metadata and previews associated with the network only applicable to safetensors networks, not downloaded diffuser models - **Wiki** - Models page updated with links to original model repos and model licenses, thanks @alerikaisattera - Updated Model-Support with newly supported models - Updated Offload, Prompting, API pages - **API** - add `/sdapi/v1/checkpoint` POST endpoint to simply load a model - add `/sdapi/v1/modules` GET endpoint to get info on model components/modules - all generate endpoints now support `sd_model_checkpoint` parameter this allows to specify which model to use for generation without needing to use additional endpoints - **Refactor** - change default huggingface cache folder from system default to `models/huggingface` sd.next will warn on startup on unused cache entries - new unified pipeline component loader in `pipelines/generic` - remove **LDSR** - remove `api-only` cli option - **Docker** - update cuda base image: `pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime` - update official builds: - **Fixes** - refactor legacy processing loop - fix settings components mismatch - fix *Wan 2.2-5B I2V* workflow - fix *Wan* T2I workflow - fix OpenVINO - fix video model vs pipeline mismatch - fix video generic save frames - fix inpaint image metadata - fix processing image save loop - fix progress bar with refine/detailer - fix api progress reporting endpoint - fix `openvino` backend failing to compile - fix `zluda` with hip-sdk==6.4 - fix `nunchaku` fallback on unsupported model - fix `nunchaku` windows download links - fix *Flux.1-Kontext-Dev* with variable resolution - use `utf_16_be` as primary metadata decoding - fix `sd35` width/height alignment - fix `nudenet` api - fix global state tracking - fix ui tab detection for networks - fix ui checkbox/radio styling for non-default themes - fix loading custom transformers and t5 
safetensors tunes - add mtime to reference models - patch torch version so 3rd party libraries can use expected format - unified stat size/mtime calls - reapply offloading on ipadapter load - api set default script-name - avoid forced gc and rely on thresholds - add missing interrogate in output panel ## Update for 2025-07-29 ### Highlights for 2025-07-29 This is a big one: simply looking at number of changes, probably the biggest release since the project started! Feature highlights include: - [ModernUI](https://github.com/user-attachments/assets/6f156154-0b0a-4be2-94f0-979e9f679501) has quite some redesign which should make it more user friendly and easier to navigate plus several new UI themes If you're still using **StandardUI**, give [ModernUI](https://vladmandic.github.io/sdnext-docs/Themes/) a try! - New models such as [WanAI 2.2](https://wan.video/) in 5B and A14B variants for both *text-to-video* and *image-to-video* workflows as well as *text-to-image* workflow! and also [FreePik F-Lite](https://huggingface.co/Freepik/F-Lite), [Bria 3.2](https://huggingface.co/briaai/BRIA-3.2) and [bigASP 2.5](https://civitai.com/models/1789765?modelVersionId=2025412) - Redesigned [Video](https://vladmandic.github.io/sdnext-docs/Video) interface with support for general video models plus optimized [FramePack](https://vladmandic.github.io/sdnext-docs/FramePack) and [LTXVideo](https://vladmandic.github.io/sdnext-docs/LTX) support - Fully integrated nudity detection and optional censorship with [NudeNet](https://vladmandic.github.io/sdnext-docs/NudeNet) - New background replacement and relightning methods using **Latent Bridge Matching** and new **PixelArt** processing filter - Enhanced auto-detection of default sampler types/settings results in avoiding common mistakes - Additional **LLM/VLM** models available for captioning and prompt enhance - Number of workflow and general quality-of-life improvements, especially around **Styles**, **Detailer**, **Preview**, **Batch**, 
**Control** - Compute improvements - [Wiki](https://github.com/vladmandic/automatic/wiki) & [Docs](https://vladmandic.github.io/sdnext-docs/) updates, especially new end-to-end [Parameters](https://vladmandic.github.io/sdnext-docs/Parameters/) page In this release we finally break with legacy with the removal of the original [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/) codebase which has not been maintained for a while now This plus major cleanup of codebase and external dependencies resulted in ~55k LoC (*lines-of-code*) reduction and spread over [~750 files](https://github.com/vladmandic/sdnext/pull/4017) in ~200 commits! We also switched project license to [Apache-2.0](https://github.com/vladmandic/sdnext/blob/dev/LICENSE.txt) which means that SD.Next is now fully compatible with commercial and non-commercial use and redistribution regardless of modifications! And (*as always*) many bugfixes and improvements to existing features! For details, see [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) > [!NOTE] > We recommend clean install for this release due to sheer size of changes > Although upgrades and existing installations are tested and should work fine! ![Screenshot](https://github.com/user-attachments/assets/6f156154-0b0a-4be2-94f0-979e9f679501) [ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) ### Details for 2025-07-29 - **License** - SD.Next [license](https://github.com/vladmandic/sdnext/blob/dev/LICENSE.txt) switched from **aGPL-v3.0** to **Apache-v2.0** this means that SD.Next is now fully compatible with commercial and non-commercial use and redistribution regardless of modifications! 
- **Models** - [WanAI Wan 2.2](https://github.com/Wan-Video/Wan2.2) both 5B and A14B variants, for both T2V and I2V support go to: *video -> generic -> wan -> pick variant* optimized support with *VACE*, etc. will follow soon *caution* Wan2.2 on its own is ~68GB, but also includes optional second-stage for later low-noise processing which is absolutely massive at additional ~54GB you can enable second stage processing in *settings -> model options*, its disabled by default *note*: quantization and offloading are highly recommended regardless of first-stage only or both stages! - [WanAI Wan](https://wan.video/) T2V models for T2I workflows Wan is originally designed for *video* workflows, but now also be used for *text-to-image* workflows! supports *Wan-2.1 in 1.3B* and 14B variants and *Wan-2.2 in 5B and A14B* variants supports all standard features such as quantization, offloading, TAESD preview generation, LoRA support etc. can also load unet/transformer fine-tunes in safetensors format using UNET loader simply select in *networks -> models -> reference* *note* 1.3B model is a bit too small for good results and 14B is very large at 78GB even without second-stage so aggressive quantization and offloading are recommended - [FreePik F-Lite](https://huggingface.co/Freepik/F-Lite) in *7B, 10B and Texture* variants F-Lite is a 7B/10B model trained exclusively on copyright-safe and SFW content, trained on internal dataset comprising approximately 80 million copyright-safe images available via *networks -> models -> reference* - [Bria 3.2](https://huggingface.co/briaai/BRIA-3.2) Bria is a smaller 4B parameter model built entirely on licensed data and safe for commercial use *note*: this is a gated model, you need to [accept terms](https://huggingface.co/briaai/BRIA-3.2) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/) available via *networks -> models -> reference* - [bigASP 2.5](https://civitai.com/models/1789765) bigASP is an 
experimental SDXL finetune using Flow matching method load as usual, and leave sampler set to *Default* or you can use following samplers: *UniPC, DPM, DEIS, SA* required sampler settings: *prediction-method=flow-prediction*, *sigma-method=flowmatch* recommended sampler settings: *flow-shift=1.0* - [LBM: Latent Bridge Matching](https://github.com/gojasper/LBM) very fast automatic image background replacement methods with relighting! *simple*: automatic background replacement using [BiRefNet](https://github.com/ZhengPeng7/BiRefNet) *relighting*: automatic background replacement with relighting so source image fits desired background with optional composite blending available in *img2img or control -> scripts* - add **FLUX.1-Kontext-Dev** inpaint workflow - add **FLUX.1-Kontext-Dev** **Nunchaku** support *note*: FLUX.1 Kontext is about 2-3x faster with Nunchaku vs standard execution! - support **FLUX.1** all-in-one safetensors - support for [Google Gemma 3n](https://huggingface.co/google/gemma-3n-E4B-it) E2B and E4B LLM/VLM models available in **prompt enhance** and process **captioning** - support for [HuggingFace SmolLM3](https://huggingface.co/HuggingFaceTB/SmolLM3-3B) 3B LLM model available in **prompt enhance** - add [fal AuraFlow 0.2](https://huggingface.co/fal/AuraFlow-v0.2) in addition to existing [fal AuraFlow 0.3](https://huggingface.co/fal/AuraFlow-v0.3) due to large differences in model behavior available via *networks -> models -> reference* - add integrated [NudeNet](https://vladmandic.github.io/sdnext-docs/NudeNet) as built-in functionality *note*: used to be available as a separate [extension](https://github.com/vladmandic/sd-extension-nudenet) - **Video** - redesigned **Video** interface - support for **Generic** video models includes support for many video models without specific per-model optimizations included: *Hunyuan, LTX, WAN, Mochi, Latte, Allegro, Cog* supports quantization, offloading, frame interpolation, etc. 
- support for optimized [FramePack](https://vladmandic.github.io/sdnext-docs/FramePack) with *t2i, i2i, flf2v* workflows LoRA support, prompt enhance, etc. now fully integrated instead of being a separate extension - support for optimized [LTXVideo](https://vladmandic.github.io/sdnext-docs/LTX) with *t2i, i2i, v2v* workflows optional native upsampling and video refine workflows LoRA support with different conditioning types such as Canny/Depth/Pose, etc. - support for post load quantization - **UI** - major update to modernui layout - add new Windows-like *Blocks* UI theme - redesign of the *Flat* UI theme - enhanced look&feel for *Gallery* tab with better search and collapsible sections, thanks to @CalamitousFelicitousness - **WIKI** - new [Parameters](https://vladmandic.github.io/sdnext-docs/Parameters/) page that lists and explains all generation parameters massive thanks to @CalamitousFelicitousness for bringing this to life! - updated *Models, Video, LTX, FramePack, Styles*, etc. - **Compute** - support for [SageAttention2++](https://github.com/thu-ml/SageAttention) provides 10-15% performance improvement over default SDPA for transformer-based models! enable in *settings -> compute settings -> sdp options* *note*: SD.Next will use either SageAttention v1/v2/v2++, depending which one is installed until authors provide pre-built wheels for v2++, you need to install it manually or SD.Next will auto-install v1 - support for `torch.compile` for LLM: captioning/prompt-enhance - support for `torch.compile` with repeated-blocks reduces time-to-compile 5x without loss of performance! 
enable in *settings -> model compile -> repeated* *note*: torch.compile is not compatible with balanced offload - **Other** - **Styles** can now include both generation params and server settings see [Styles docs](https://vladmandic.github.io/sdnext-docs/Styles/) for details - **TAESD** is now default preview type since it's the only one that supports most new models - support **TAESD** preview and remote VAE for **HunyuanDit** - support **TAESD** preview and remote VAE for **AuraFlow** - support **TAESD** preview for **WanAI** - SD.Next now starts with *locked* state preventing model loading until startup is complete - warn when modifying legacy settings that are no longer supported, but available for compatibility - warn on incompatible sampler and automatically restore default sampler - **XYZ grid** can now work with control tab: if controlnet or processor are selected in xyz grid, they will overwrite settings from first unit in control tab, when using controlnet/processor selected in xyz grid, behavior is forced as control-only also freely selectable are control strength, start and end values - **Batch** warn on unprocessable images and skip operations on errors so that other images can still be processed - **Metadata** improved parsing and detect foreign metadata detect ComfyUI images detect InvokeAI images - **Detailer** add `expert` mode where list of detailer models can be converted to textbox for manual editing see [docs](https://vladmandic.github.io/sdnext-docs/Detailer/) for more information - **Detailer** add option to merge multiple results from each detailer model for example, hands model can result in two hands each being processed separately or both hands can be merged into one composite job - **Control** auto-update width/height on image upload - **Control** auto-determine image save path depending on operations performed - autodetect **V-prediction** models and override default sampler prediction type as needed - **SDNQ** - use inference context 
during quantization - use static compile - rename quantization type for text encoders `default` option to `Same as model` - **API** - add `/sdapi/v1/lock-checkpoint` endpoint that can be used to lock/unlock model changes if model is locked, it cannot be changed using normal load or unload methods - **Fixes** - allow theme type `None` to be set in config - installer don't cache installed state - fix Cosmos-Predict2 retrying TAESD download - better handle startup import errors - fix traceback width preventing copy&paste - fix ansi control output from scripts/extensions - fix diffusers models non-unique hash - fix loading of manually downloaded diffuser models - fix api `/sdapi/v1/embeddings` endpoint - fix incorrect reporting of deleted and modified files - fix SD3.x loader and TAESD preview - fix xyz with control enabled - fix control order of image save operations - fix control batch-input processing - fix modules merge save model - fix torchvision bicubic upsample with ipex - fix instantir pipeline - fix prompt encoding if prompts within batch have different segment counts - fix detailer min/max size - fix loopback script - fix networks tags display - fix yolo refresh models - cleanup control infotext - allow upscaling with models that have implicit VAE processing - framepack improve offloading - improve prompt parser tokenizer loader - improve scripts error handling - improve infotext param parsing - improve extensions ui search - improve model type autodetection - improve model auth check for hf repos - improve Chroma prompt padding as per recommendations - lock directml torch to `torch-directml==0.2.4.dev240913` - lock directml transformers to `transformers==4.52.4` - improve install of `sentencepiece` tokenizer - add int8 matmul fallback for ipex with onednn qlinear - **Refactoring** *note*: none of the removals result in loss-of-functionality since all those features are already re-implemented goal here is to remove legacy code, code duplication and reduce 
code complexity - obsolete **original backend** - remove majority of legacy **a1111** codebase - remove legacy ldm codebase: `/repositories/ldm` - remove legacy blip codebase: `/repositories/blip` - remove legacy codeformer codebase: `/repositories/codeformer` - remove legacy clip patch model: `/models/karlo` - remove legacy model configs: `/configs/*.yaml` - remove legacy submodule: `/modules/k-diffusion` - remove legacy hypernetworks support: `/modules/hypernetworks` - remove legacy lora support: `/extensions-builtin/Lora` - remove legacy clip/blip interrogate module - remove modern-ui remove `only-original` vs `only-diffusers` code paths - refactor control processing and separate preprocessing and image save ops - refactor modernui layouts to rely on accordions more than individual controls - refactore pipeline apply/unapply optional components & features - split monolithic `shared.py` - cleanup `/modules`: move pipeline loaders to `/pipelines` root - cleanup `/modules`: move code folders used by pipelines to `/pipelines/` folder - cleanup `/modules`: move code folders used by scripts to `/scripts/ `); document.close(); } function restartReload() { document.body.style = 'background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center'; document.body.innerHTML = '

Server shutdown in progress...

'; authFetch(`${window.api}/progress?skip_current_image=true`) .then((res) => setTimeout(restartReload, 1000)) .catch((e) => setTimeout(monitorServerStatus, 500)); return []; } function updateInput(target) { const e = new Event('input', { bubbles: true }); Object.defineProperty(e, 'target', { value: target }); target.dispatchEvent(e); } let desiredCheckpointName = null; function selectCheckpoint(name) { desiredCheckpointName = name; const tabName = getENActiveTab(); const btnModel = gradioApp().getElementById(`${tabName}_extra_model`); const isRefiner = btnModel && btnModel.classList.contains('toolbutton-selected'); if (isRefiner) gradioApp().getElementById('change_refiner').click(); else gradioApp().getElementById('change_checkpoint').click(); log(`selectCheckpoint ${isRefiner ? 'refiner' : 'model'}: ${desiredCheckpointName}`); markSelectedCards([desiredCheckpointName], 'model'); } let desiredVAEName = null; function selectVAE(name) { desiredVAEName = name; gradioApp().getElementById('change_vae').click(); log(`selectVAE: ${desiredVAEName}`); markSelectedCards([desiredVAEName], 'vae'); } function selectReference(name) { log(`selectReference: ${name}`); desiredCheckpointName = name; gradioApp().getElementById('change_reference').click(); markSelectedCards([desiredCheckpointName], 'model'); } function currentImageResolutionimg2img(_a, _b, scaleBy) { const img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img'); return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]; } function currentImageResolutioncontrol(_a, _b, scaleBy) { const img = gradioApp().querySelector('#control-tab-input > div[style="display: block;"] img'); return img ? 
/**
 * Preview the theme currently entered in the `setting_gradio_theme` setting.
 *
 * Loads `data/themes.json`; when it is an array containing an entry whose `id`
 * equals the selected name, that entry's `subdomain` is opened in a new tab.
 * Otherwise the `#theme-preview` image (created on demand) is toggled and
 * pointed at `html/<name>.jpg`, with every `/` in the name replaced by `-`.
 * Fetch and parse failures are reported through `error()`.
 */
function previewTheme() {
  // tolerate a missing settings element or missing <input> instead of throwing
  const input = gradioApp().getElementById('setting_gradio_theme')?.querySelector('input');
  const name = input?.value || '';
  fetch(`${window.subpath}/file=data/themes.json`)
    .then((res) => res.json()) // returned so one catch covers both fetch and parse errors
    .then((themes) => {
      const theme = Array.isArray(themes) ? themes.find((t) => t.id === name) : null;
      if (theme) {
        window.open(theme.subdomain, '_blank');
        return;
      }
      const el = document.getElementById('theme-preview') || createThemeElement();
      el.style.display = el.style.display === 'block' ? 'none' : 'block';
      // replace every slash, and use the same subpath prefix as the fetch above
      const fileName = name.replace(/\//g, '-');
      el.src = `${window.subpath}/file=html/${fileName}.jpg`;
    })
    .catch((e) => error(`previewTheme: ${e}`));
}
/**
 * Click the `ui_defaults_view` button whenever it becomes visible.
 * Does nothing when the button is not present in the page.
 */
async function getUIDefaults() {
  const viewButton = gradioApp().getElementById('ui_defaults_view');
  if (viewButton) {
    // monitor visibility of tab: a positive intersection ratio means the button is on screen
    const onVisibilityChange = ([first]) => {
      if (first.intersectionRatio > 0) viewButton.click();
    };
    new IntersectionObserver(onVisibilityChange).observe(viewButton);
  }
}
@lru_cache()
def _run_cached(command, desc, errdesc, env_items, live):
    """Execute `command` through the shell; memoized on its (hashable) arguments.

    `env_items` is either None (use `os.environ`) or a tuple of `(key, value)`
    pairs describing the environment to run with.
    """
    if desc is not None:
        installer.log.info(desc)
    env = os.environ if env_items is None else dict(env_items)
    if live:
        # live mode: output goes straight to the console, nothing is captured
        result = subprocess.run(command, check=False, shell=True, env=env)
        if result.returncode != 0:
            raise RuntimeError(f"""{errdesc or 'Error running command'}
Command: {command}
Error code: {result.returncode}""")
        return ''
    result = subprocess.run(command, stdout=subprocess.PIPE, check=False, stderr=subprocess.PIPE, shell=True, env=env)
    if result.returncode != 0:
        stdout = result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else ''
        stderr = result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else ''
        raise RuntimeError(f"""{errdesc or 'Error running command'}: {command} code: {result.returncode}
{stdout}
{stderr}
""")
    return result.stdout.decode(encoding="utf8", errors="ignore")


def run(command, desc=None, errdesc=None, custom_env=None, live=False): # compatibility function
    """Run a shell command and return its decoded stdout.

    Results of successful calls are memoized, so repeating an identical call
    returns the earlier output without executing the command again.

    Args:
        command: command line passed to the shell.
        desc: optional message logged at info level before running.
        errdesc: optional prefix for the error message on failure.
        custom_env: optional mapping used as the process environment instead of
            `os.environ`.
        live: when True, output is not captured and '' is returned.

    Returns:
        Captured stdout as text, or '' in live mode.

    Raises:
        RuntimeError: when the command exits with a non-zero return code.
    """
    # a mapping is unhashable and cannot be an lru_cache key, so freeze it first
    env_items = None if custom_env is None else tuple(sorted(custom_env.items()))
    return _run_cached(command, desc, errdesc, env_items, live)


# keep the cache-management helpers that the lru_cache-wrapped function exposed
run.cache_clear = _run_cached.cache_clear
run.cache_info = _run_cached.cache_info
def clean_server():
    """Drop server-related modules from `sys.modules` and run a garbage collection.

    Every loaded module whose name starts with one of the prefixes in
    `modules_to_remove` is removed from `sys.modules`. Afterwards the remaining
    modules are summarized per top-level package and logged at trace level
    together with the removal and timing statistics.
    """
    t0 = time.time()
    import gc
    modules_loaded = sorted(sys.modules.keys())
    modules_to_remove = ['webui', 'modules', 'scripts', 'gradio', 'onnx', 'torch', 'pytorch', 'lightning', 'tensor', 'diffusers', 'transformers', 'tokenize', 'safetensors', 'gguf', 'accelerate', 'peft', 'triton', 'huggingface', 'PIL', 'cv2', 'timm', 'numpy', 'scipy', 'sympy', 'sklearn', 'skimage', 'sqlalchemy', 'flash_attn', 'bitsandbytes', 'xformers', 'matplotlib', 'optimum', 'pandas', 'pi', 'git', 're', 'altair', 'framepack', 'nudenet', 'agent_scheduler', 'basicsr', 'gfpgan', 'war', 'fastapi', 'urllib', 'uvicorn', 'web', 'http', 'google', 'starlette', 'socket']
    prefixes = tuple(modules_to_remove) # str.startswith accepts a tuple of alternatives
    removed_removed = []
    for module_loaded in modules_loaded:
        # remove each matching module exactly once instead of retrying per prefix
        if module_loaded.startswith(prefixes) and module_loaded in sys.modules:
            del sys.modules[module_loaded]
            removed_removed.append(module_loaded)
    collected = gc.collect() # python gc
    modules_cleaned = sorted(sys.modules.keys())
    # count remaining modules per top-level package in a single pass; private (underscore) modules are skipped
    modules_sorted = {}
    for module_name in modules_cleaned:
        if module_name.startswith('_'):
            continue
        module_key = module_name.split('.')[0]
        modules_sorted[module_key] = modules_sorted.get(module_key, 0) + 1
    installer.log.trace(f'Server modules: {modules_sorted}')
    t1 = time.time()
    installer.log.trace(f'Server modules: total={len(modules_loaded)} unloaded={len(removed_removed)} remaining={len(modules_cleaned)} gc={collected} time={t1-t0:.2f}')
installer.ensure_base_requirements() init_args() # setup argparser and default folders installer.args = args installer.setup_logging() installer.log.info('Starting SD.Next') installer.get_logfile() try: sys.excepthook = installer.custom_excepthook except Exception: pass installer.read_options() if args.skip_all: args.quick = True installer.check_python() if args.reset: installer.git_reset() if args.skip_git or args.skip_all: installer.log.info('Skipping GIT operations') installer.check_version() installer.log.info(f'Platform: {installer.print_dict(installer.get_platform())}') installer.check_venv() installer.log.info(f'Args: {sys.argv[1:]}') if not args.skip_env or args.skip_all: installer.set_environment() if args.uv: installer.install("uv", "uv") installer.install_gradio() installer.check_torch() installer.check_onnx() installer.check_transformers() installer.check_diffusers() installer.check_modified_files() if args.test: installer.log.info('Startup: test mode') installer.quick_allowed = False if args.reinstall: installer.log.info('Startup: force reinstall of all packages') installer.quick_allowed = False if args.skip_all: installer.log.info('Startup: skip all') installer.quick_allowed = True init_paths() else: installer.install_requirements() installer.install_packages() if installer.check_timestamp(): installer.log.info('Startup: quick launch') init_paths() installer.check_extensions() else: installer.log.info('Startup: standard') installer.install_submodules() init_paths() installer.install_extensions() installer.install_requirements() # redo requirements since extensions may change them installer.update_wiki() if len(installer.errors) == 0: installer.log.debug(f'Setup complete without errors: {round(time.time())}') else: installer.log.warning(f'Setup complete with errors: {installer.errors}') installer.log.warning(f'See log file for more details: {installer.log_file}') installer.extensions_preload(parser) # adds additional args from extensions args = 
class MomentumBuffer:
    """Running accumulator where each update adds the new value to the momentum-scaled previous total."""

    def __init__(self, momentum_val: float):
        # the accumulator starts at plain 0 and takes the type of the first update value
        self.running_average = 0
        self.momentum = momentum_val

    def update(self, update_value: torch.Tensor):
        """Fold `update_value` into `running_average` in place; returns nothing."""
        self.running_average = update_value + self.momentum * self.running_average
def project(
    v0: torch.Tensor, # [B, C, H, W]
    v1: torch.Tensor, # [B, C, H, W]
):
    """Split `v0` into components parallel and orthogonal to `v1`.

    `v1` is normalized over its last three dimensions and the projection is
    computed per batch item in float64. Both results are returned on the
    original device and in the original dtype of `v0`.

    Returns:
        Tuple `(v0_parallel, v0_orthogonal)` with `v0_parallel + v0_orthogonal`
        equal to `v0` up to rounding.
    """
    device = v0.device
    dtype = v0.dtype
    # float64 math is done on cpu for these backends; mps has no float64 support at all
    if device.type in ("xpu", "mps"):
        v0, v1 = v0.to("cpu"), v1.to("cpu")
    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel.to(device, dtype=dtype), v0_orthogonal.to(device, dtype=dtype)
@dataclass
class StableCascadePriorPipelineOutput(BaseOutput):
    """
    Output class for the Stable Cascade prior pipeline.

    Args:
        image_embeddings (`torch.Tensor` or `np.ndarray`)
            Prior image embeddings for text prompt
        prompt_embeds (`torch.Tensor` or `np.ndarray`):
            Text embeddings for the prompt.
        prompt_embeds_pooled (`torch.Tensor` or `np.ndarray`):
            Pooled text embeddings for the prompt.
        negative_prompt_embeds (`torch.Tensor` or `np.ndarray`):
            Text embeddings for the negative prompt.
        negative_prompt_embeds_pooled (`torch.Tensor` or `np.ndarray`):
            Pooled text embeddings for the negative prompt.
    """

    image_embeddings: Union[torch.Tensor, np.ndarray]
    prompt_embeds: Union[torch.Tensor, np.ndarray]
    prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
    negative_prompt_embeds: Union[torch.Tensor, np.ndarray]
    negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`StableCascadeUNet`]): The Stable Cascade prior to approximate the image embedding from the text and/or image embedding. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). feature_extractor ([`~transformers.CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `image_encoder`. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. resolution_multiple ('float', *optional*, defaults to 42.67): Default resolution for multiple images generated. 
""" unet_name = "prior" text_encoder_name = "text_encoder" model_cpu_offload_seq = "image_encoder->text_encoder->prior" _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, prior: StableCascadeUNet, scheduler: DDPMWuerstchenScheduler, resolution_multiple: float = 42.67, feature_extractor: Optional[CLIPImageProcessor] = None, image_encoder: Optional[CLIPVisionModelWithProjection] = None, ) -> None: super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, image_encoder=image_encoder, feature_extractor=feature_extractor, prior=prior, scheduler=scheduler, ) self.register_to_config(resolution_multiple=resolution_multiple) def prepare_latents( self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler ): latent_shape = ( num_images_per_prompt * batch_size, self.prior.config.in_channels, ceil(height / self.config.resolution_multiple), ceil(width / self.config.resolution_multiple), ) if latents is None: latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != latent_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def encode_prompt( self, device, batch_size, num_images_per_prompt, do_classifier_free_guidance, prompt=None, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_pooled: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, ): if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, 
truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] attention_mask = attention_mask[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True ) prompt_embeds = text_encoder_output.hidden_states[-1] if prompt_embeds_pooled is None: prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) if negative_prompt_embeds is None and do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds_text_encoder_output = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device), output_hidden_states=True, ) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) seq_len = negative_prompt_embeds_pooled.shape[1] negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( dtype=self.text_encoder.dtype, device=device ) negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( batch_size * num_images_per_prompt, seq_len, -1 ) # done duplicates return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt): image_embeds = [] for image in images: image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) image_embed = self.image_encoder(image).image_embeds.unsqueeze(1) image_embeds.append(image_embed) image_embeds = torch.cat(image_embeds, dim=1) image_embeds = 
image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) negative_image_embeds = torch.zeros_like(image_embeds) return image_embeds, negative_image_embeds def check_inputs( self, prompt, images=None, image_embeds=None, negative_prompt=None, prompt_embeds=None, prompt_embeds_pooled=None, negative_prompt_embeds=None, negative_prompt_embeds_pooled=None, callback_on_step_end_tensor_inputs=None, ): if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) if prompt_embeds is not None and prompt_embeds_pooled is None: raise ValueError( "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" ) if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" ) if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: raise ValueError( "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed" f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !=" f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}." ) if image_embeds is not None and images is not None: raise ValueError( f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to" " only forward one of the two." ) if images: for i, image in enumerate(images): if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): raise TypeError( f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got" f"{type(image)} for image number {i}." 
) @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps def get_timestep_ratio_conditioning(self, t, alphas_cumprod): s = torch.tensor([0.008]) clamp_range = [0, 1] min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 var = alphas_cumprod[t] var = var.clamp(*clamp_range) s, min_var = s.to(var.device), min_var.to(var.device) ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s return ratio @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, height: int = 1024, width: int = 1024, num_inference_steps: int = 20, timesteps: List[float] = None, guidance_scale: float = 4.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_pooled: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pt", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to 1024): The height in pixels of the generated image. width (`int`, *optional*, defaults to 1024): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 60): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 8.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `decoder_guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input argument. image_embeds (`torch.Tensor`, *optional*): Pre-generated image embeddings. 
Can be used to easily tweak image inputs, *e.g.* prompt weighting. If not provided, image embeddings will be generated from the `images` input argument when it is given; otherwise zero embeddings are used. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generated image embeddings. Choose between: `"np"` (`np.array`) or `"pt"` (`torch.Tensor`); any value other than `"np"` leaves the outputs as tensors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`StableCascadePriorPipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that is called at the end of each denoising step during inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated image embeddings. """ # 0. Define commonly used variables device = self._execution_device dtype = next(self.prior.parameters()).dtype self._guidance_scale = guidance_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, images=images, image_embeds=image_embeds, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) # 2. Encode caption + images ( prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled, ) = self.encode_prompt( prompt=prompt, device=device, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, ) if images is not None: image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image( images=images, device=device, dtype=dtype, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, ) elif image_embeds is not None: image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) else: image_embeds_pooled = torch.zeros( batch_size * num_images_per_prompt, 1, self.prior.config.clip_image_in_channels, device=device, dtype=dtype, ) uncond_image_embeds_pooled = torch.zeros( batch_size * num_images_per_prompt, 1, 
self.prior.config.clip_image_in_channels, device=device, dtype=dtype, ) if self.do_classifier_free_guidance: image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0) else: image_embeds = image_embeds_pooled # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_encoder_hidden_states = ( torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds ) text_encoder_pooled = ( torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) if negative_prompt_embeds is not None else prompt_embeds_pooled ) # 4. Prepare and set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latents latents = self.prepare_latents( batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler ) if isinstance(self.scheduler, DDPMWuerstchenScheduler): timesteps = timesteps[:-1] else: if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: self.scheduler.config.clip_sample = False # disample sample clipping logger.warning(" set `clip_sample` to be False") # 6. Run denoising loop if hasattr(self.scheduler, "betas"): alphas = 1.0 - self.scheduler.betas alphas_cumprod = torch.cumprod(alphas, dim=0) else: alphas_cumprod = [] self._num_timesteps = len(timesteps) for i, t in enumerate(self.progress_bar(timesteps)): if not isinstance(self.scheduler, DDPMWuerstchenScheduler): if len(alphas_cumprod) > 0: timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) else: timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) else: timestep_ratio = t.expand(latents.size(0)).to(dtype) # 7. 
Denoise image embeddings predicted_image_embedding = self.prior( sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, clip_text_pooled=text_encoder_pooled, clip_text=text_encoder_hidden_states, clip_img=image_embeds, return_dict=False, )[0] # 8. Check for classifier free guidance and apply it if self.do_classifier_free_guidance: predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) predicted_image_embedding = apg.normalized_guidance(predicted_image_embedding_text, predicted_image_embedding_uncond, self.guidance_scale) # 9. Renoise latents to next timestep if not isinstance(self.scheduler, DDPMWuerstchenScheduler): timestep_ratio = t latents = self.scheduler.step( model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # Offload all models self.maybe_free_model_hooks() if output_type == "np": latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work negative_prompt_embeds = ( negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None ) # float() as bfloat16-> numpy doesnt work if not return_dict: return ( latents, prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled, ) return StableCascadePriorPipelineOutput( 
image_embeddings=latents, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, ) ================================================ FILE: modules/apg/pipeline_stable_diffision_xl_apg.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the spread of `noise_pred_text`.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Guided noise prediction.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Weight of the rescaled prediction in the blend; the remaining
            `1 - guidance_rescale` goes to the untouched `noise_cfg`.

    Returns:
        The blended tensor.
    """
    # Standard deviation over every dimension except the leading one.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_axes, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_axes, keepdim=True)

    # Rescaling the guided result fixes overexposure ...
    rescaled = noise_cfg * (text_std / cfg_std)
    # ... and mixing it with the original guided result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Call `scheduler.set_timesteps` and hand back the schedule it produced.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list,
    or a plain `num_inference_steps` count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of steps, used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of
        inference steps (the schedule length when a custom schedule was supplied, otherwise
        `num_inference_steps` as passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's
            `set_timesteps` does not accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler build its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must expose the matching keyword argument.
    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """Register the pipeline components and derive the attributes computed from them.

    Args:
        vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, scheduler,
        image_encoder, feature_extractor: Components passed to `register_modules`.
        force_zeros_for_empty_prompt: Stored in the pipeline config.
        add_watermarker: Whether to create a watermarker. When `None`, the result of
            `is_invisible_watermark_available()` decides.
    """
    super().__init__()

    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # One factor of two for every VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size

    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    # NOTE: `StableDiffusionXLWatermarker` is imported at module level only when
    # `is_invisible_watermark_available()` is true.
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None 
else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. 
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode `image` with `self.image_encoder` and build the matching unconditional output.

    Args:
        image: A tensor, or any other input that `self.feature_extractor` turns into
            `pixel_values`.
        device: Device the pixel tensor is moved to before encoding.
        num_images_per_prompt: Repeat count applied along dim 0 of the outputs.
        output_hidden_states: When truthy, return the second-to-last hidden state of the
            encoder instead of its `image_embeds`.

    Returns:
        A `(conditional, unconditional)` pair. With hidden states, the unconditional part is
        the encoder output for an all-zero image; otherwise it is a zero tensor shaped like
        the conditional embeddings.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    cond_hidden = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    cond_hidden = cond_hidden.repeat_interleave(num_images_per_prompt, dim=0)

    blank_image = torch.zeros_like(image)
    uncond_hidden = self.image_encoder(blank_image, output_hidden_states=True).hidden_states[-2]
    uncond_hidden = uncond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build one embedding tensor per IP Adapter, tiled for `num_images_per_prompt`.

    Args:
        ip_adapter_image: Image or list of images, encoded with `self.encode_image` when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional list of precomputed embeddings. Under
            classifier-free guidance each entry is split in two along dim 0, negative
            half first.
        device: Device every returned tensor is moved to.
        num_images_per_prompt: Number of copies concatenated along dim 0.
        do_classifier_free_guidance: When true, the tiled negative embeddings are
            concatenated in front of the tiled positive ones.

    Returns:
        A list of tensors, one per adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection
            layers on `self.unet.encoder_hid_proj`.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_part, packed = packed.chunk(2)
                negatives.append(negative_part)
            positives.append(packed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        pairs = zip(ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers)
        for adapter_image, proj_layer in pairs:
            # Hidden states are requested for every layer that is not an `ImageProjection`.
            want_hidden = not isinstance(proj_layer, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, want_hidden)
            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))
    return prepared


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only
    included when the signature lists a parameter of that name. eta (η) corresponds to η in
    the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        A dict holding `eta` and/or `generator`, possibly empty.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {key: value for key, value in candidates if key in step_params}
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments of `__call__` before any work is done.

    Checks, in order: image size divisibility by 8, the legacy `callback_steps` value, the requested
    callback tensor names, mutually exclusive prompt / embedding arguments and their types, matching
    embedding shapes, presence of pooled embeddings, and the IP-Adapter inputs.

    Returns:
        `None` when every check passes.

    Raises:
        `ValueError`: on the first violated constraint.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Prompts and pre-computed prompt embeddings are mutually exclusive, but one of them is required.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        # An empty list used to surface as an IndexError from `[0]`; report it as a validation error.
        if not ip_adapter_image_embeds:
            raise ValueError("`ip_adapter_image_embeds` has to be a non-empty list of 3D or 4D tensors but is empty")
        # Validate every adapter's tensor, not only the first one.
        for image_embeds in ip_adapter_image_embeds:
            if image_embeds.ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {image_embeds.ndim}D"
                )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    When `latents` is given it is only moved to `device`; otherwise noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn
    with `randn_tensor`.

    Raises:
        `ValueError`: if `generator` is a list whose length differs from `batch_size`.
    """
    latent_shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial_latents = latents.to(device)
    else:
        initial_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial_latents * self.scheduler.init_noise_sigma
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." ) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise as determined by the discrete timesteps selected by the
            scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
            "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a target image resolution. It should be the
            same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2
            of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read. When `callback` is given
            without `callback_steps`, it is invoked on every step.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    # A legacy `callback` without `callback_steps` used to fail on `i % None` inside the loop;
    # fall back to calling it on every step.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width to unet
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        callback_steps,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # NOTE(review): `guidance_rescale` is stored here but not applied anywhere in the loop below;
    # confirm whether that is intended for the APG variant.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._denoising_end = denoising_end
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = self._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
    else:
        negative_add_time_ids = add_time_ids

    # With guidance, the negative conditioning is stacked in front of the positive one so that
    # one UNet pass produces both predictions.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    # 8.1 Apply denoising_end
    if (
        self.denoising_end is not None
        and isinstance(self.denoising_end, float)
        and self.denoising_end > 0
        and self.denoising_end < 1
    ):
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (self.denoising_end * self.scheduler.config.num_train_timesteps)
            )
        )
        num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
        timesteps = timesteps[:num_inference_steps]

    # 9. Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                added_cond_kwargs["image_embeds"] = image_embeds
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance: the two halves of the batch are combined by APG's normalized guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = apg.normalized_guidance(
                    noise_pred_cond, noise_pred_uncond, guidance_scale=self.guidance_scale
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # Requested tensors are looked up by name among this function's locals, so the
                # local variable names used here must stay stable.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
                add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type != "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        elif latents.dtype != self.vae.dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                self.vae = self.vae.to(latents.dtype)

        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            # `-1` lets the channel axis follow the length of the configured statistics
            # instead of assuming exactly 4 latent channels.
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    if output_type != "latent":
        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the guided noise prediction with a copy of itself whose per-sample standard deviation has been matched to
    the text-conditioned prediction.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: Noise prediction after classifier-free guidance has been applied.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Mixing weight. `0.0` returns `noise_cfg` unchanged in value, `1.0` returns the fully
            std-matched prediction.

    Returns:
        The mixed noise prediction.
    """
    # Statistics are reduced over every dimension except the leading one.
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # Std-matched guidance result (counteracts overexposure).
    std_matched = noise_cfg * (text_std / cfg_std)

    # Interpolate with the plain guidance result so images do not end up looking "plain".
    return guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and read back the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the call; extra keyword arguments are
    forwarded to `scheduler.set_timesteps` untouched.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps (the
        length of the schedule when a custom schedule was supplied, otherwise `num_inference_steps` as passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: make sure the scheduler can take the override before calling it.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds`
        is reshaped to `(bs_embed * num_images_per_prompt, seq_len, -1)`. `negative_prompt_embeds` is processed the
        same way only when `do_classifier_free_guidance` is true; otherwise it is returned exactly as it was passed
        in (which may be `None`).

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when one is given, otherwise from the
    # leading dimension of the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report
        # the part of the prompt that was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then unet, then whatever the
    # embeddings already use.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad/truncate the negative prompt to the sequence length of the
        # positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # NOTE(review): this view uses `batch_size`, so caller-supplied
        # `negative_prompt_embeds` are assumed to have `batch_size` rows — confirm against callers.
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            # NOTE(review): reached even when `lora_scale` is None (no scaling was applied above);
            # presumably `unscale_lora_layers` tolerates that — TODO confirm.
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter IP-Adapter embeddings that are handed to the unet.

    Embeddings are either taken from `ip_adapter_image_embeds` (when given) or computed from `ip_adapter_image`
    via `self.encode_image`. Each entry is then repeated `num_images_per_prompt` times along dim 0 and, under
    classifier-free guidance, prefixed with its equally repeated negative counterpart.

    Args:
        ip_adapter_image: A single image input or a list with one entry per IP Adapter. Only used when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional list of pre-computed embeddings. Under classifier-free guidance each
            tensor is split in two along dim 0, the first half being the negative embedding.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt: Repetition count along dim 0.
        do_classifier_free_guidance: Whether negative embeddings are tracked and prepended.

    Returns:
        A list with one tensor per IP Adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers of the unet.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        # Pre-computed path: unpack the (negative, positive) halves when guidance is on.
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers

        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(images, projection_layers):
            # Plain `ImageProjection` layers consume pooled embeddings; anything else gets hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            positive, negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(positive[None, :])
            if do_classifier_free_guidance:
                negatives.append(negative[None, :])

    prepared = []
    for index, positive in enumerate(positives):
        tiled = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the signature
    declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward as `generator`.
        eta: Value to forward as `eta`.

    Returns:
        A dict containing `"eta"` and/or `"generator"` (in that order) for the parameters the scheduler accepts.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed; they are multiplied by 1000 before the sinusoids are
            evaluated.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. The first half holds sines, the second half cosines; an odd
            dimension is zero-padded by one column. Values below 4 make `half_dim - 1` zero and therefore yield
            non-finite frequencies.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, located on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Allocate the frequency table on `w`'s device: a CPU-only table makes the product below fail with a
    # device mismatch whenever `w` lives on an accelerator. The 0-dim CPU scalar `emb` broadcasts fine.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. 
Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) else None ) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = apg.normalized_guidance(noise_pred_cond, noise_pred_uncond, guidance_scale=self.guidance_scale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, 
class Api:
    """Route registry and basic-auth gate for the server's HTTP API.

    On construction it collects ``user:password`` credentials from the
    ``--auth`` / ``--auth-file`` command-line options, optionally mounts the
    docs pages, and creates the generate/process/control handlers that share
    ``queue_lock``. ``register`` then attaches every route to ``app``.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Collect credentials and build the endpoint handler objects.

        Args:
            app: FastAPI application that routes are added to.
            queue_lock: lock handed to the generate/process/control handlers.

        Raises:
            ValueError: if a credential entry has no ``:`` separator.
        """
        self.credentials = {}
        if shared.cmd_opts.auth:
            for auth in shared.cmd_opts.auth.split(","):
                self._add_credential(auth)
        if shared.cmd_opts.auth_file:
            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                for line in file:
                    if not line.strip():
                        continue  # tolerate blank lines (e.g. a trailing newline at end of file)
                    self._add_credential(line)
        self.router = APIRouter()
        if shared.cmd_opts.docs:
            docs.create_docs(app)
            docs.create_redocs(app)
        self.app = app
        self.queue_lock = queue_lock
        self.generate = generate.APIGenerate(queue_lock)
        self.process = process.APIProcess(queue_lock)
        self.control = control.APIControl(queue_lock)
        # compatibility api
        self.text2imgapi = self.generate.post_text2img
        self.img2imgapi = self.generate.post_img2img

    def _add_credential(self, entry: str):
        """Parse one ``user:password`` entry and store it in ``self.credentials``.

        Only the first ``:`` separates user from password, so passwords may
        themselves contain ``:``. Double quotes and surrounding whitespace are
        stripped from both parts. An entry without ``:`` raises ``ValueError``
        so that a malformed configuration fails loudly instead of silently
        running with fewer credentials.
        """
        user, password = entry.split(":", 1)
        self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()

    def register(self):
        """Attach all API routes to the application."""
        # fetch js/css
        self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)

        # server api
        self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
        self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
        self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
        self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
        self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
        self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
        self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
        self.add_api_route("/sdapi/v1/history", server.get_history, methods=["GET"], response_model=list[models.ResHistory])
        self.add_api_route("/sdapi/v1/interrupt", server.post_interrupt, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", server.post_skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
        self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
        self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=List[models.ResGPU])

        # core api using locking
        self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
        self.add_api_route("/sdapi/v1/img2img", self.generate.post_img2img, methods=["POST"], response_model=models.ResImg2Img)
        self.add_api_route("/sdapi/v1/control", self.control.post_control, methods=["POST"], response_model=control.ResControl)
        self.add_api_route("/sdapi/v1/extra-single-image", self.process.extras_single_image_api, methods=["POST"], response_model=models.ResProcessImage)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.process.extras_batch_images_api, methods=["POST"], response_model=models.ResProcessBatch)
        self.add_api_route("/sdapi/v1/preprocess", self.process.post_preprocess, methods=["POST"])
        self.add_api_route("/sdapi/v1/mask", self.process.post_mask, methods=["POST"])
        self.add_api_route("/sdapi/v1/detect", self.process.post_detect, methods=["POST"])
        self.add_api_route("/sdapi/v1/prompt-enhance", self.process.post_prompt_enhance, methods=["POST"], response_model=models.ResPromptEnhance)

        # api dealing with optional scripts
        self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
        self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])

        # enumerator api
        self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
        self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
        self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
        self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
        self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/face-restorers", endpoints.get_restorers, methods=["GET"], response_model=List[models.ItemDetailer])
        self.add_api_route("/sdapi/v1/detailers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
        self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
        self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
        self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
        self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
        self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])

        # functional api
        self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
        self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"])
        self.add_api_route("/sdapi/v1/checkpoint", endpoints.get_checkpoint, methods=["GET"])
        self.add_api_route("/sdapi/v1/checkpoint", endpoints.set_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
        self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
        self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)

        # lora api
        from modules.api import loras
        loras.register_api()

        # gallery api
        from modules.api import gallery
        gallery.register_api(self.app)

        # nudenet api
        from modules.api import nudenet
        nudenet.register_api()

        # xyz-grid api
        from modules.api import xyz_grid
        xyz_grid.register_api()

        # civitai api
        from modules.civitai import api_civitai
        api_civitai.register_api()

    def add_api_route(self, path: str, fn, auth: bool = True, **kwargs):
        """Register ``fn`` at ``path`` and, when a subpath is configured, also under that subpath.

        Args:
            path: route path.
            fn: endpoint callable.
            auth: when true and credentials are configured, ``self.auth`` is
                appended to the route's dependencies.
            **kwargs: forwarded to ``FastAPI.add_api_route``.
        """
        if auth and self.credentials:
            # copy so a caller-supplied list is not mutated; `or []` also accepts an explicit None
            deps = list(kwargs.get('dependencies') or [])
            deps.append(Depends(self.auth))
            kwargs['dependencies'] = deps
        if shared.opts.subpath is not None and len(shared.opts.subpath) > 0:
            self.app.add_api_route(f'{shared.opts.subpath}{path}', endpoint=fn, **kwargs)
        self.app.add_api_route(path, endpoint=fn, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Validate HTTP basic credentials.

        Access is granted when no credentials are configured, when the
        user/password pair matches a configured credential, or when the
        supplied password is a key of ``self.app.tokens``.

        Raises:
            HTTPException: 401 when none of the above applies.
        """
        if not self.credentials:
            return True
        expected = self.credentials.get(credentials.username)
        if expected is not None:
            # compare as bytes: compare_digest() raises TypeError for non-ASCII str arguments
            if compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
                return True
        if hasattr(self.app, 'tokens') and (self.app.tokens is not None):
            if credentials.password in self.app.tokens:
                return True
        shared.log.error(f'API authentication: user="{credentials.username}"')
        raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})

    def get_session_start(self, req: Request, agent: Optional[str] = None):
        """Log the start of a browser session and return an empty payload."""
        token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
        # `tokens` may be missing or None (auth() guards the same way)
        tokens = getattr(self.app, 'tokens', None)
        user = tokens.get(token) if tokens is not None else None
        # the client address is not guaranteed to be present on every request
        host = req.client.host if req.client is not None else None
        shared.log.info(f'Browser session: user={user} client={host} agent={agent}')
        return {}

    def launch(self):
        """Start a Uvicorn server for the application and return the server object."""
        config = {
            "listen": shared.cmd_opts.listen,
            "port": shared.cmd_opts.port,
            "keyfile": shared.cmd_opts.tls_keyfile,
            "certfile": shared.cmd_opts.tls_certfile,
            "loop": "auto",  # auto, asyncio, uvloop
            "http": "auto",  # auto, h11, httptools
        }
        from modules.server import UvicornServer
        http_server = UvicornServer(self.app, **config)
        # from modules.server import HypercornServer
        # server = HypercornServer(self.app, **config)
        http_server.start()
        shared.log.info(f'API server: Uvicorn options={config}')
        return http_server
class APIControl():
    """Handler for the control endpoint.

    Translates a control request into arguments for ``run.control_run``,
    executes it under ``queue_lock`` and packs the results into ``ResControl``.
    The most recently built control units are kept in ``self.units`` so that a
    following request using the same processor/model pair can reuse them.
    """

    def __init__(self, queue_lock: Lock):
        self.queue_lock = queue_lock
        self.default_script_arg = []
        self.units = []

    def sanitize_args(self, args: dict):
        """Turn a request object into the keyword dict passed to ``run.control_run``.

        API-only fields are dropped and ``script_name`` / ``script_args`` are
        renamed to ``override_script_name`` / ``override_script_args``.
        """
        args = vars(args)
        for key in ('sampler_name', 'alwayson_scripts', 'face', 'face_id', 'ip_adapter', 'save_images'):
            args.pop(key, None)
        args['override_script_name'] = args.pop('script_name', None)
        args['override_script_args'] = args.pop('script_args', None)
        return args

    def sanitize_b64(self, request):
        """Blank out long string script arguments (1000+ chars) on ``request``, in place.

        Applies to the ``args`` lists of ``alwayson_scripts`` entries and to
        ``script_args``. A non-empty ``override_script_args`` attribute is
        removed from the request altogether.
        """
        def sanitize_str(args: list):
            for idx, value in enumerate(args):
                if isinstance(value, str) and len(value) >= 1000:
                    args[idx] = ""

        if hasattr(request, "alwayson_scripts") and request.alwayson_scripts:
            for script_obj in request.alwayson_scripts.values():
                if script_obj and "args" in script_obj and script_obj["args"]:
                    sanitize_str(script_obj["args"])
        if hasattr(request, "script_args") and request.script_args:
            sanitize_str(request.script_args)
        if hasattr(request, 'override_script_args') and request.override_script_args:
            # request is an attribute-style object with no dict-like pop(); delete the attribute instead
            delattr(request, 'override_script_args')

    def prepare_face_module(self, req):
        """Map the ``face`` field onto the "face" script, unless a script is already selected."""
        if hasattr(req, "face") and req.face and not req.script_name and (not req.alwayson_scripts or "face" not in req.alwayson_scripts):
            req.script_name = "face"
            req.script_args = [
                req.face.mode,
                req.face.source_images,
                req.face.ip_model,
                req.face.ip_override_sampler,
                req.face.ip_cache_model,
                req.face.ip_strength,
                req.face.ip_structure,
                req.face.id_strength,
                req.face.id_conditioning,
                req.face.id_cache,
                req.face.pm_trigger,
                req.face.pm_strength,
                req.face.pm_start,
                req.face.fs_cache
            ]
            del req.face

    def prepare_xyz_grid(self, req):
        """Map the ``xyz`` field onto the "xyz grid" script arguments and remove the field."""
        if hasattr(req, "xyz") and req.xyz:
            req.script_name = "xyz grid"
            req.script_args = [
                req.xyz.x_type, req.xyz.x_values, '',
                req.xyz.y_type, req.xyz.y_values, '',
                req.xyz.z_type, req.xyz.z_values, '',
                False, # csv_mode
                req.xyz.draw_legend,
                False, # no_fixed_seeds
                req.xyz.include_grid, req.xyz.include_subgrids, req.xyz.include_images,
                req.xyz.include_time, req.xyz.include_text,
            ]
            del req.xyz

    def prepare_ip_adapter(self, request):
        """Collect IP-adapter settings into a dict of parallel lists.

        Adapters without images are skipped; images and masks are decoded from
        base64. Returns an empty dict when the request has no ``ip_adapter``.
        The ``ip_adapter`` attribute is removed from the request once consumed.
        """
        if not (hasattr(request, "ip_adapter") and request.ip_adapter):
            return {}
        args = { 'ip_adapter_names': [], 'ip_adapter_scales': [], 'ip_adapter_crops': [], 'ip_adapter_starts': [], 'ip_adapter_ends': [], 'ip_adapter_images': [], 'ip_adapter_masks': [] }
        for ipadapter in request.ip_adapter:
            if not ipadapter.images:
                continue
            args['ip_adapter_names'].append(ipadapter.adapter)
            args['ip_adapter_scales'].append(ipadapter.scale)
            args['ip_adapter_starts'].append(ipadapter.start)
            args['ip_adapter_ends'].append(ipadapter.end)
            args['ip_adapter_crops'].append(ipadapter.crop)
            args['ip_adapter_images'].append([helpers.decode_base64_to_image(x) for x in ipadapter.images])
            if ipadapter.masks:
                args['ip_adapter_masks'].append([helpers.decode_base64_to_image(x) for x in ipadapter.masks])
        del request.ip_adapter
        return args

    def prepare_control(self, req):
        """Build ``req.units`` from ``req.control`` and remove ``req.control``.

        A unit cached in ``self.units`` at the same index is reused when its
        processor and model ids match the requested ones; otherwise a new
        ``Unit`` is created. On an unknown unit type an error is logged and
        ``req.units`` is left empty.
        """
        from modules.control.unit import Unit, unit_types
        req.units = []
        if req.unit_type is None:
            req.unit_type = 'controlnet'
        if req.unit_type not in unit_types:
            shared.log.error(f'Control unknown unit type: type={req.unit_type} available={unit_types}')
            return
        # `control` is Optional in the request model, so an explicit null must behave like an empty list
        for i, u in enumerate(req.control or []):
            if (len(self.units) > i) and (self.units[i].process_id == u.process) and (self.units[i].model_id == u.model):
                unit = self.units[i]
                unit.enabled = True
                unit.strength = u.strength
                unit.start = u.start
                unit.end = u.end
            else:
                unit = Unit(
                    enabled = True,
                    unit_type = req.unit_type,
                    model_id = u.model,
                    process_id = u.process,
                    strength = u.strength,
                    start = u.start,
                    end = u.end,
                )
            if u.override is not None:
                unit.override = helpers.decode_base64_to_image(u.override)
            req.units.append(unit)
        self.units = req.units
        del req.control

    def post_control(self, req: ReqControl):
        """Run a control job for ``req`` and return its images, processed images and info text.

        The job runs while holding ``queue_lock``; the job state opened with
        ``shared.state.begin`` is always closed, even when the run raises.
        """
        requested = req.control
        self.prepare_face_module(req)
        self.prepare_control(req)
        self.prepare_xyz_grid(req)

        # prepare args
        args = req.copy(update={  # Override __init__ params
            "sampler_index": processing_helpers.get_sampler_index(req.sampler_name),
            "is_generator": True,
            "inputs": [helpers.decode_base64_to_image(x) for x in req.inputs] if req.inputs else None,
            "inits": [helpers.decode_base64_to_image(x) for x in req.inits] if req.inits else None,
            "mask": helpers.decode_base64_to_image(req.mask) if req.mask else None,
        })
        args = self.sanitize_args(args)
        send_images = args.pop('send_images', True)

        # run
        with self.queue_lock:
            jobid = shared.state.begin('API-CTL', api=True)
            output_images = []
            output_processed = []
            output_info = ''
            try:
                run.control_set({
                    'do_not_save_grid': not req.save_images,
                    'do_not_save_samples': not req.save_images,
                    **self.prepare_ip_adapter(req),
                })
                run.control_set(getattr(req, "extra", {}))
                res = run.control_run(**args)
                for item in res:
                    if len(item) > 0 and (isinstance(item[0], list) or item[0] is None):
                        # (images, processed, info) result item
                        output_images += item[0] if item[0] is not None else []
                        output_processed += [item[1]] if len(item) > 1 and item[1] is not None else []
                        output_info += item[2] if len(item) > 2 and item[2] is not None else ''
                    elif isinstance(item, str):
                        output_info += item
            finally:
                # close the job state even if the run failed
                shared.state.end(jobid)

        # return
        b64images = list(map(helpers.encode_pil_to_base64, output_images)) if send_images else []
        b64processed = list(map(helpers.encode_pil_to_base64, output_processed)) if send_images else []
        self.sanitize_b64(req)
        # report the originally requested control items instead of the built unit objects
        req.units = requested
        return ResControl(images=b64images, processed=b64processed, params=vars(req), info=output_info)
) -> HTMLResponse: current_swagger_ui_parameters = swagger_ui_default_parameters.copy() if swagger_ui_parameters: current_swagger_ui_parameters.update(swagger_ui_parameters) html = f""" {title}
def get_samplers():
    """List the diffusers sampler configs as ``{'name', 'options'}`` dicts.

    The 'All', 'Default' and 'Res4Lyf' entries of the config are left out.
    """
    from modules import sd_samplers_diffusers
    excluded = ('All', 'Default', 'Res4Lyf')
    return [
        {'name': name, 'options': options}
        for name, options in sd_samplers_diffusers.config.items()
        if name not in excluded
    ]
def get_sd_models():
    """Return one ``models.ItemModel`` for every entry in ``sd_checkpoint.checkpoints_list``."""
    from modules import sd_checkpoint
    return [
        models.ItemModel(
            title=info.title,
            model_name=info.name,
            filename=info.filename,
            type=info.type,
            hash=info.shorthash,
            sha256=info.sha256,
            config=None,
        )
        for info in sd_checkpoint.checkpoints_list.values()
    ]
def post_interrogate(req: models.ReqInterrogate):
    """Caption the base64 image in ``req`` with the requested interrogation model.

    Raises:
        HTTPException: 404 when the image is missing or shorter than 64
            characters, or when the model is neither deepbooru nor one of the
            models returned by ``refresh_clip_models``.
    """
    if req.image is None or len(req.image) < 64:
        raise HTTPException(status_code=404, detail="Image not found")
    image = helpers.decode_base64_to_image(req.image).convert('RGB')

    if req.model in ("deepdanbooru", "deepbooru"):
        from modules.interrogate import deepbooru
        return models.ResInterrogate(caption=deepbooru.model.tag(image))

    from modules.interrogate.openclip import interrogate_image, analyze_image, refresh_clip_models
    if req.model not in refresh_clip_models():
        raise HTTPException(status_code=404, detail="Model not found")
    try:
        caption = interrogate_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode)
    except Exception as e:
        # a failed interrogation is reported through the caption text
        caption = str(e)
    if not req.analyze:
        return models.ResInterrogate(caption=caption)
    medium, artist, movement, trending, flavor, _ = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model)
    return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
def post_unload_checkpoint():
    """Unload the weights of the main model, then of the refiner; returns an empty dict."""
    from modules import sd_models
    for target in ('model', 'refiner'):
        sd_models.unload_model_weights(op=target)
    return {}
def post_pnginfo(req: models.ReqImageInfo):
    """Extract generation info from a base64-encoded image.

    Args:
        req: request whose `image` attribute holds the encoded image.

    Returns:
        `models.ResImageInfo` with an empty `info` when no image was supplied
        or it decoded to None; otherwise the info string, the items read from
        the image and the parameters parsed from the info string.
    """
    from modules import images, script_callbacks, infotext
    # tolerate a missing image (None) as well as an empty/whitespace-only string
    encoded = (req.image or "").strip()
    if not encoded:
        return models.ResImageInfo(info="")
    image = helpers.decode_base64_to_image(encoded)
    if image is None:
        return models.ResImageInfo(info="")
    geninfo, items = images.read_info_from_image(image)
    if geninfo is None:
        geninfo = ""
    params = infotext.parse(geninfo)
    script_callbacks.infotext_pasted_callback(geninfo, params)
    return models.ResImageInfo(info=geninfo, items=items, parameters=params)
class ConnectionManager:
    """Keeps the list of accepted browser WebSockets and sends data to them."""

    def __init__(self):
        self.active: list[WebSocket] = []

    async def connect(self, ws: WebSocket):
        """Accept the socket and register it in `active`."""
        await ws.accept()
        agent = ws._headers.get("user-agent", "") # pylint: disable=protected-access
        debug(f'Browser WS connect: client={ws.client.host} agent="{agent}"')
        self.active.append(ws)

    def disconnect(self, ws: WebSocket):
        """Unregister the socket; a socket that was never registered is ignored."""
        debug(f'Browser WS disconnect: client={ws.client.host}')
        # connect() may have failed before the socket was appended, in which
        # case an unconditional remove() would raise ValueError
        if ws in self.active:
            self.active.remove(ws)

    async def send(self, ws: WebSocket, data: Union[str, dict, bytes]):
        """Send `data` using the frame type matching its Python type.

        Nothing is sent when the socket is not in the CONNECTED state; data of
        an unsupported type is only reported through `debug`.
        """
        # debug(f'Browser WS send: client={ws.client.host} data={type(data)}')
        if ws.client_state != WebSocketState.CONNECTED:
            return
        if isinstance(data, bytes):
            await ws.send_bytes(data)
        elif isinstance(data, dict):
            await ws.send_json(data)
        elif isinstance(data, str):
            await ws.send_text(data)
        else:
            debug(f'Browser WS send: client={ws.client.host} data={type(data)} unknown')

    async def broadcast(self, data: Union[str, dict, bytes]):
        """Send `data` to every socket registered when the broadcast started."""
        # iterate over a snapshot: connect()/disconnect() can run while a send
        # is awaited and must not mutate the list being iterated
        for ws in list(self.active):
            await self.send(ws, data)
def get_image_thumbnail(filepath):
    """Build the gallery thumbnail payload for an image file.

    Returns:
        dict with `exif` (info string read from the image), `data` (JPEG
        thumbnail as a base64 data URL), the original `width`/`height`, file
        `size` and `mtime` in milliseconds; an empty dict if anything fails
        (the error is logged).
    """
    try:
        stat_size, stat_mtime = modelstats.stat(filepath)
        # context managers release the source file and the converted copy
        # even when reading or encoding fails part-way through
        with Image.open(filepath) as image:
            geninfo, _items = images.read_info_from_image(image)
            width, height = image.width, image.height
            h = shared.opts.extra_networks_card_size
            w = shared.opts.extra_networks_card_size if shared.opts.browser_fixed_width else width * h // height
            with image.convert('RGB') as thumb:
                thumb.thumbnail((w, h), Image.Resampling.HAMMING)
                buffered = io.BytesIO()
                thumb.save(buffered, format='jpeg')
        data_url = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
        content = {
            'exif': geninfo,
            'data': data_url,
            'width': width,
            'height': height,
            'size': stat_size,
            'mtime': stat_mtime.timestamp() * 1000, # JS timestamps use milliseconds
        }
        return content
    except Exception as e:
        shared.log.error(f'Gallery image: file="{filepath}" {e}')
        return {}
folders.append(make_folder(base_grids, os.path.basename(base_grids.rstrip('/\\')))) # Use the specific folder setting values as labels (e.g., "outputs/text" -> "outputs/text") folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_txt2img_samples), shared.opts.outdir_txt2img_samples)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_img2img_samples), shared.opts.outdir_img2img_samples)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_control_samples), shared.opts.outdir_control_samples)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_extras_samples), shared.opts.outdir_extras_samples)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_save), shared.opts.outdir_save)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_video), shared.opts.outdir_video)) folders.append(make_folder(resolve_output_path(base_samples, shared.opts.outdir_init_images), shared.opts.outdir_init_images)) folders.append(make_folder(resolve_output_path(base_grids, shared.opts.outdir_txt2img_grids), shared.opts.outdir_txt2img_grids)) folders.append(make_folder(resolve_output_path(base_grids, shared.opts.outdir_img2img_grids), shared.opts.outdir_img2img_grids)) folders.append(make_folder(resolve_output_path(base_grids, shared.opts.outdir_control_grids), shared.opts.outdir_control_grids)) # Custom browser folders and reference dir for f in shared.opts.browser_folders.split(','): f = f.strip() if f: folders.append(make_folder(f)) folders.append(make_folder(reference_dir, 'Reference')) # Filter empty and duplicates (by path) seen_paths = set() unique_folders = [] for f in folders: path = f["path"].strip() if path and path not in seen_paths and os.path.isdir(path): seen_paths.add(path) unique_folders.append(f) if shared.demo is not None and path not in shared.demo.allowed_paths: debug(f'Browser folders allow: {path}') 
async def ht_files(folder: str):
    """List all files below `folder` (recursively) for the gallery.

    Returns:
        list of strings of the form `<quoted folder>##F##<quoted relative path>`;
        an empty list if enumeration fails (the error is logged).
    """
    try:
        t0 = time.time()
        files = files_cache.directory_files(folder, recursive=True)
        # the folder part is identical for every entry, so quote it once
        prefix = quote(folder)
        if prefix[1:4] == "%3A":
            # restore a literal colon in second position (presumably a Windows drive letter such as "C:")
            prefix = prefix[:1] + ":" + prefix[4:]
        prefix += '##F##'
        lines = [prefix + quote(os.path.relpath(f, folder)) for f in files]
        t1 = time.time()
        shared.log.debug(f'Gallery: type=ht folder="{folder}" files={len(lines)} time={t1-t0:.3f}')
        return lines
    except Exception as e:
        shared.log.error(f'Gallery: {folder} {e}')
        return []
def sanitize_b64(self, request):
    """Blank out very long string arguments of a request, in place.

    Every string of 1000 characters or more found in `request.script_args`
    and in the `"args"` list of each `request.alwayson_scripts` entry is
    replaced (presumably these are base64 payloads, per the method name).
    Missing or empty attributes are skipped.
    """
    def sanitize_str(args: list):
        for idx, value in enumerate(args):
            if isinstance(value, str) and len(value) >= 1000:
                # NOTE(review): the replacement is an empty string in this source - confirm no placeholder text was intended
                args[idx] = ""

    alwayson_scripts = getattr(request, "alwayson_scripts", None)
    if alwayson_scripts:
        for script_obj in alwayson_scripts.values():
            if script_obj and "args" in script_obj and script_obj["args"]:
                sanitize_str(script_obj["args"])
    script_args = getattr(request, "script_args", None)
    if script_args:
        sanitize_str(script_args)
def prepare_ip_adapter(self, request, p):
    """Copy the IP-adapter settings of an API request onto the processing object.

    When `request.ip_adapter` is present and non-empty, the `ip_adapter_*`
    lists on `p` are rebuilt from it (adapters without images are skipped,
    images and masks are decoded from base64) and the attribute is then
    deleted from the request.
    """
    if hasattr(request, "ip_adapter") and request.ip_adapter:
        p.ip_adapter_names = []
        p.ip_adapter_scales = []
        p.ip_adapter_crops = []
        p.ip_adapter_starts = []
        p.ip_adapter_ends = []
        p.ip_adapter_images = []
        # initialised once together with the other lists: resetting it while
        # adapters are being processed dropped the masks of all earlier adapters
        p.ip_adapter_masks = []
        for ipadapter in request.ip_adapter:
            if not ipadapter.images:
                continue
            p.ip_adapter_names.append(ipadapter.adapter)
            p.ip_adapter_scales.append(ipadapter.scale)
            p.ip_adapter_crops.append(ipadapter.crop)
            p.ip_adapter_starts.append(ipadapter.start)
            p.ip_adapter_ends.append(ipadapter.end)
            p.ip_adapter_images.append([helpers.decode_base64_to_image(x) for x in ipadapter.images])
            if ipadapter.masks:
                p.ip_adapter_masks.append([helpers.decode_base64_to_image(x) for x in ipadapter.masks])
        del request.ip_adapter
populate.sampler_index = None # prevent a warning later on args = self.sanitize_args(populate) send_images = args.pop('send_images', True) with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) self.prepare_ip_adapter(txt2imgreq, p) p.scripts = script_runner p.outpath_grids = resolve_output_path(shared.opts.outdir_grids, shared.opts.outdir_txt2img_grids) p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_txt2img_samples) for key, value in getattr(txt2imgreq, "extra", {}).items(): setattr(p, key, value) jobid = shared.state.begin('API-TXT', api=True) script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) p.script_args = tuple(script_args) # Need to pass args as tuple here if selectable_scripts is not None: processed = scripts_manager.scripts_txt2img.run(p, *script_args) # Need to pass args as list here else: processed = process_images(p) processed = scripts_manager.scripts_txt2img.after(p, processed, *script_args) p.close() shared.state.end(jobid) if processed is None or processed.images is None or len(processed.images) == 0: b64images = [] else: b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else [] self.sanitize_b64(txt2imgreq) info = processed.js() if processed else '' return models.ResTxt2Img(images=b64images, parameters=vars(txt2imgreq), info=info) def post_img2img(self, img2imgreq: models.ReqImg2Img): self.prepare_face_module(img2imgreq) init_images = img2imgreq.init_images if init_images is None: return JSONResponse(status_code=400, content={"error": "Init image is none"}) mask = img2imgreq.mask if mask: mask = helpers.decode_base64_to_image(mask) script_runner = scripts_manager.scripts_img2img if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui(None) if not self.default_script_arg_img2img: self.default_script_arg_img2img = 
script.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = script.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params "sampler_name": helpers.validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": not img2imgreq.save_images, "do_not_save_grid": not img2imgreq.save_images, "mask": mask, }) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on args = self.sanitize_args(populate) send_images = args.pop('send_images', True) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) self.prepare_ip_adapter(img2imgreq, p) p.init_images = [helpers.decode_base64_to_image(x) for x in init_images] p.scripts = script_runner p.outpath_grids = resolve_output_path(shared.opts.outdir_grids, shared.opts.outdir_img2img_grids) p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_img2img_samples) for key, value in getattr(img2imgreq, "extra", {}).items(): setattr(p, key, value) jobid = shared.state.begin('API-IMG', api=True) script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) p.script_args = tuple(script_args) # Need to pass args as tuple here if selectable_scripts is not None: processed = scripts_manager.scripts_img2img.run(p, *script_args) # Need to pass args as list here else: processed = process_images(p) processed = scripts_manager.scripts_img2img.after(p, processed, *script_args) p.close() shared.state.end(jobid) if processed is None or processed.images is None or len(processed.images) == 0: b64images = [] else: b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else [] if not img2imgreq.include_init_images: img2imgreq.init_images = None img2imgreq.mask = None self.sanitize_b64(img2imgreq) info = processed.js() 
def decode_base64_to_image(encoding, quiet=False):
    """Decode a base64 string, optionally wrapped in a `data:image/...` URL, into a PIL image.

    Args:
        encoding: base64 text or data URL; None yields None.
        quiet: when true, failures return None instead of raising.

    Returns:
        The opened image, or None (input was None, or decoding failed with `quiet`).

    Raises:
        HTTPException: status 500 when decoding fails and `quiet` is false.
    """
    if encoding is None:
        return None
    try:
        if encoding.startswith("data:image/"):
            # the payload is everything after the first comma; parsed inside the
            # try so that a malformed header is reported like any other bad input
            _header, separator, encoding = encoding.partition(",")
            if not separator:
                raise ValueError("data URL has no payload")
        decoded = base64.b64decode(encoding)
        data = io.BytesIO(decoded)
        image = Image.open(data)
        return image
    except Exception as e:
        shared.log.warning(f'API cannot decode image: {e}')
        if not quiet:
            raise HTTPException(status_code=500, detail="Invalid encoded image") from e
        return None
def get_loras():
    """Return name, alias, path and metadata for every entry of `lora_load.available_networks`."""
    from modules.lora import network, lora_load

    def describe(entry: network.NetworkOnDisk):
        # the file name is exposed under the "path" key
        return {
            "name": entry.name,
            "alias": entry.alias,
            "path": entry.filename,
            "metadata": entry.metadata,
        }

    return list(map(describe, lora_load.available_networks.values()))
================================================ FILE: modules/api/middleware.py ================================================ import ssl import time import logging from asyncio.exceptions import CancelledError import anyio import starlette import uvicorn import fastapi from starlette.responses import JSONResponse from fastapi import FastAPI, Request, Response from fastapi.exceptions import HTTPException from fastapi.encoders import jsonable_encoder from installer import log import modules.errors as errors errors.install() ignore_endpoints = [ '/sdapi/v1/log', '/sdapi/v1/browser', '/sdapi/v1/gpu', '/sdapi/v1/network/thumb', '/sdapi/v1/progress', ] def setup_middleware(app: FastAPI, cmd_opts): ssl._create_default_https_context = ssl._create_unverified_context # pylint: disable=protected-access uvicorn_logger=logging.getLogger("uvicorn.error") uvicorn_logger.disabled = True from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] app.middleware_stack = None # reset current middleware to allow modifying user provided list app.add_middleware(GZipMiddleware, minimum_size=2048) if cmd_opts.cors_origins and cmd_opts.cors_regex: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_origins.split(','), allow_origin_regex=cmd_opts.cors_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_origins: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_regex: app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) @app.middleware("http") async def log_and_time(req: Request, call_next): try: ts = time.time() res: Response = await call_next(req) duration = str(round(time.time() - ts, 4)) 
res.headers["X-Process-Time"] = duration endpoint = req.scope.get('path', 'err') token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure") if (cmd_opts.api_log) and endpoint.startswith('/sdapi'): if any([endpoint.startswith(x) for x in ignore_endpoints]): # noqa C419 # pylint: disable=use-a-generator return res log.info('API user={user} code={code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation user = app.tokens.get(token) if hasattr(app, 'tokens') else None, code = res.status_code, ver = req.scope.get('http_version', '0.0'), cli = req.scope.get('client', ('0:0.0.0', 0))[0], prot = req.scope.get('scheme', 'err'), method = req.scope.get('method', 'err'), endpoint = endpoint, duration = duration, )) return res except CancelledError: log.warning('WebSocket closed (ignore asyncio.exceptions.CancelledError)') except BaseException as e: return handle_exception(req, e) def handle_exception(req: Request, e: Exception): err = { "error": type(e).__name__, "code": vars(e).get('status_code', 500), "detail": vars(e).get('detail', ''), "body": vars(e).get('body', ''), "errors": str(e), } if err['code'] == 401 and 'file=' in req.url.path: # dont spam with unauth return JSONResponse(status_code=err['code'], content=jsonable_encoder(err)) if err['code'] == 404 and 'file=html/' in req.url.path: # dont spam with locales return JSONResponse(status_code=err['code'], content=jsonable_encoder(err)) if not any([req.url.path.endswith(x) for x in ignore_endpoints]): # noqa C419 # pylint: disable=use-a-generator log.error(f"API error: {req.method}: {req.url} {err}") if not isinstance(e, HTTPException) and err['error'] != 'TypeError': # do not print backtrace on known httpexceptions errors.display(e, 'HTTP API', [anyio, fastapi, uvicorn, starlette]) elif err['code'] in [404, 401, 400]: pass else: log.debug(e, exc_info=True) # print stack trace return 
class PydanticModelGenerator:
    """Builds a pydantic model from the `__init__` signatures of a class hierarchy.

    Args:
        model_name: name given to the generated model.
        class_instance: class whose MRO `__init__` parameters become fields
            (names listed in `API_NOT_ALLOWED` are skipped).
        additional_fields: optional list of dicts with `key`, `type`,
            `default` and optionally `exclude`, appended as extra fields.
        exclude_fields: optional list of (underscored) field names to drop.
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
        exclude_fields: Optional[List] = None,  # None instead of a shared mutable [] default
    ):
        def field_type_generator(_k, v):
            # every generated field is optional, whatever the parameter annotation says
            field_type = v.annotation
            return Optional[field_type]

        def merge_class_params(class_):
            # walk the MRO (minus `object`); a later class overwrites the
            # parameter object of a same-named earlier one but keeps its position
            parameters = {}
            for klass in inspect.getmro(class_):
                if klass is object:
                    continue
                parameters.update(inspect.signature(klass.__init__).parameters)
            return parameters

        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default
            )
            for (k, v) in self._class_data.items() if k not in API_NOT_ALLOWED
        ]

        # the default None used to be iterated directly and raised TypeError
        for fld in additional_fields or []:
            self._model_def.append(ModelDef(
                field=underscore(fld["key"]),
                field_alias=fld["key"],
                field_type=fld["type"],
                field_value=fld["default"],
                field_exclude=fld.get("exclude", False)))

        # filter once instead of rebuilding the list for every excluded name
        excluded = list(exclude_fields or [])
        if excluded:
            self._model_def = [x for x in self._model_def if x.field not in excluded]

    def generate_model(self):
        """Create and return the dynamic model described by the collected field definitions."""
        model_fields = {
            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude))
            for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **model_fields)
        try:
            DynamicModel.__config__.allow_population_by_field_name = True
            DynamicModel.__config__.allow_mutation = True
        except Exception:
            # deliberate best effort: the model is returned even if these config flags cannot be set
            pass
        return DynamicModel
Field(title="SHA256 hash") hash: Optional[str] = Field(title="Short hash") config: Optional[str] = Field(title="Config file") class ItemHypernetwork(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") class ItemDetailer(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") class ItemGAN(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") scale: Optional[int] = Field(title="Scale") class ItemStyle(BaseModel): name: str = Field(title="Name") prompt: Optional[str] = Field(title="Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt") extra: Optional[str] = Field(title="Extra") filename: Optional[str] = Field(title="Filename") preview: Optional[str] = Field(title="Preview") class ItemExtraNetwork(BaseModel): name: str = Field(title="Name") type: str = Field(title="Type") title: Optional[str] = Field(title="Title") fullname: Optional[str] = Field(title="Fullname") filename: Optional[str] = Field(title="Filename") hash: Optional[str] = Field(title="Hash") preview: Optional[str] = Field(title="Preview image URL") class ItemArtist(BaseModel): name: str = Field(title="Name") score: float = Field(title="Score") category: str = Field(title="Category") class ItemEmbedding(BaseModel): step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available") sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. 
Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead") shape: int = Field(title="Shape", description="The length of each individual vector in the embedding") vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") class ItemIPAdapter(BaseModel): adapter: str = Field(title="Adapter", default="Base", description="IP adapter name") images: List[str] = Field(title="Image", default=[], description="IP adapter input images") masks: Optional[List[str]] = Field(title="Mask", default=[], description="IP adapter mask images") scale: float = Field(title="Scale", default=0.5, ge=0, le=1, description="IP adapter scale") start: float = Field(title="Start", default=0.0, ge=0, le=1, description="IP adapter start step") end: float = Field(title="End", default=1.0, gt=0, le=1, description="IP adapter end step") crop: bool = Field(title="Crop", default=False, description="IP adapter crop face from input") class ItemFace(BaseModel): mode: str = Field(title="Mode", default="FaceID", description="The mode to use (available values: FaceID, FaceSwap, PhotoMaker, InstantID).") source_images: list[str] = Field(title="Source Images", description="Source face images, must be base64 encoded containing the image's data.") ip_model: str = Field(title="IPAdapter Model", default="FaceID Base", description="The IPAdapter model to use.") ip_override_sampler: bool = Field(title="IPAdapter Override Sampler", default=True, description="Should the sampler be overriden?") ip_cache_model: bool = Field(title="IPAdapter Cache", default=True, description="Should the IPAdapter model be cached?") ip_strength: float = Field(title="IPAdapter Strength", default=1, ge=0, le=2, description="IPAdapter strength of the source images, must be between 0.0 and 2.0.") ip_structure: float = Field(title="IPAdapter Structure", default=1, ge=0, le=1, description="IPAdapter structure to use, must be between 0.0 and 1.0.") id_strength: 
class ScriptArg(BaseModel):
    # Describes one UI argument of a script as exposed through the API.
    # Titles now match their fields ("Maximum" and "Step" previously carried
    # the copy-pasted title "Minimum") and the "argumentin" typo is fixed.
    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
    maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
    step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
    choices: Optional[Any] = Field(default=None, title="Choices", description="Possible values for the argument")
class ItemExtension(BaseModel):
    # Describes one installed extension as returned by the API.
    name: str = Field(title="Name", description="Extension name")
    remote: str = Field(title="Remote", description="Extension Repository URL")
    # Default spelling corrected from "uknnown".
    branch: str = Field(default="unknown", title="Branch", description="Extension Repository Branch")
    commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
    version: str = Field(title="Version", description="Extension Version")
    commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
    enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
StableDiffusionTxt2ImgProcessingAPI = ReqTxt2Img class ResTxt2Img(BaseModel): images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.") parameters: dict info: str ReqImg2Img = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, [ {"key": "sampler_index", "type": Union[int, str], "default": 0}, {"key": "sampler_name", "type": str, "default": "UniPC"}, {"key": "hr_sampler_name", "type": str, "default": "Same as primary"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.5}, {"key": "mask", "type": Optional[str], "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude": True}, {"key": "script_name", "type": Optional[str], "default": ""}, {"key": "script_args", "type": list, "default": []}, {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "ip_adapter", "type": Optional[List[ItemIPAdapter]], "default": None, "exclude": True}, {"key": "face_id", "type": Optional[ItemFace], "default": None, "exclude": True}, {"key": "extra", "type": Optional[dict], "default": {}, "exclude": True}, ] ).generate_model() if not hasattr(ReqImg2Img, "__config__"): ReqImg2Img.__config__ = DummyConfig StableDiffusionImg2ImgProcessingAPI = ReqImg2Img class ResImg2Img(BaseModel): images: List[str] = Field(default=None, title="Image", description="The generated images in base64 format.") parameters: dict info: str class FileData(BaseModel): data: str = Field(title="File data", description="Base64 representation of the file") name: str = Field(title="File name") class ReqProcess(BaseModel): resize_mode: float = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") 
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?") gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in shared.sd_upscalers])}") upscaler_2: str = Field(default="None", title="Refine upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in shared.sd_upscalers])}") extras_upscaler_2_visibility: float = Field(default=0, title="Refine upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") class ResProcess(BaseModel): html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") class ReqPromptEnhance(BaseModel): prompt: str = Field(title="Prompt", description="Prompt to enhance") type: str = Field(title="Type", default='text', description="Type of enhancement to perform") model: Optional[str] = Field(title="Model", default=None, description="Model to use for enhancement") system_prompt: Optional[str] = Field(title="System prompt", default=None, description="Model system prompt") image: Optional[str] = Field(title="Image", default=None, description="Image to work on, must be a Base64 string containing the image's data.") seed: int = Field(title="Seed", default=-1, description="Seed used to generate the prompt") nsfw: bool = Field(title="NSFW", default=True, description="Should NSFW content be allowed?") class ResPromptEnhance(BaseModel): prompt: str = Field(title="Prompt", description="Enhanced prompt") seed: int = Field(title="Seed", description="Seed used to generate the prompt") class ReqProcessImage(ReqProcess): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string 
containing the image's data.") class ResProcessImage(ResProcess): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") class ReqProcessBatch(ReqProcess): imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ResProcessBatch(ResProcess): images: List[str] = Field(title="Images", description="The generated images in base64 format.") class ReqImageInfo(BaseModel): image: str = Field(title="Image", description="The base64 encoded image") class ResImageInfo(BaseModel): info: str = Field(title="Image info", description="A string with the parameters used to generate the image") items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") class ReqGetLog(BaseModel): lines: int = Field(default=100, title="Lines", description="How many lines to return") clear: bool = Field(default=False, title="Clear", description="Should the log be cleared after returning the lines?") class ReqPostLog(BaseModel): message: Optional[str] = Field(default=None, title="Message", description="The info message to log") debug: Optional[str] = Field(default=None, title="Debug message", description="The debug message to log") error: Optional[str] = Field(default=None, title="Error message", description="The error message to log") class ReqHistory(BaseModel): id: Union[int, str, None] = Field(default=None, title="Task ID", description="Task ID") class ReqProgress(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") class ResProgress(BaseModel): id: Union[int, str, None] = Field(title="TaskID", description="Task ID") progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in 
secs") state: dict = Field(title="State", description="The current state snapshot") current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") textinfo: Optional[str] = Field(default=None, title="Info text", description="Info text used by WebUI.") class ResHistory(BaseModel): id: Union[int, str, None] = Field(title="ID", description="Task ID") job: str = Field(title="Job", description="Job name") op: str = Field(title="Operation", description="Job state") timestamp: Union[float, None] = Field(title="Timestamp", description="Job timestamp") duration: Union[float, None] = Field(title="Duration", description="Job duration") outputs: List[str] = Field(title="Outputs", description="List of filenames") class ResStatus(BaseModel): status: str = Field(title="Status", description="Current status") task: str = Field(title="Task", description="Current job") timestamp: Optional[str] = Field(title="Timestamp", description="Timestamp of the current job") current: str = Field(title="Task", description="Current job") id: Union[int, str, None] = Field(title="ID", description="ID of the current task") job: int = Field(title="Job", description="Current job") jobs: int = Field(title="Jobs", description="Total jobs") total: int = Field(title="Total Jobs", description="Total jobs") step: int = Field(title="Step", description="Current step") steps: int = Field(title="Steps", description="Total steps") queued: int = Field(title="Queued", description="Number of queued tasks") uptime: int = Field(title="Uptime", description="Uptime of the server") elapsed: Optional[float] = Field(default=None, title="Elapsed time") eta: Optional[float] = Field(default=None, title="ETA in secs") progress: Optional[float] = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") class ReqInterrogate(BaseModel): image: str = Field(default="", 
class ResInterrogate(BaseModel):
    # Response of the interrogate endpoint. Each field now has its own schema
    # title; previously artist/movement/trending/flavor all reused "Medium".
    caption: Optional[str] = Field(default=None, title="Caption", description="The generated caption for the image.")
    medium: Optional[str] = Field(default=None, title="Medium", description="Image medium.")
    artist: Optional[str] = Field(default=None, title="Artist", description="Image artist.")
    movement: Optional[str] = Field(default=None, title="Movement", description="Image movement.")
    trending: Optional[str] = Field(default=None, title="Trending", description="Image trending.")
    flavor: Optional[str] = Field(default=None, title="Flavor", description="Image flavor.")
def create_model_from_signature(func: Callable, model_name: str, base_model: Type[BaseModel] = BaseModel, additional_fields: List = None, exclude_fields: List[str] = None) -> type[BaseModel]:
    """Create a pydantic model whose fields mirror the signature of ``func``.

    Args:
        func: callable whose positional and keyword-only parameters become fields.
        model_name: name of the generated model.
        base_model: base class passed to ``create_model`` as ``__base__``.
        additional_fields: optional list of dicts with ``key``, ``type``,
            ``default`` and optionally ``exclude`` entries to add as fields.
        exclude_fields: optional list of field names to leave out of the model.

    Returns:
        The generated pydantic model class.
    """
    from PIL import Image

    class Config:
        extra = 'allow'

    additional_fields = additional_fields if additional_fields is not None else []
    excluded = list(exclude_fields) if exclude_fields is not None else []

    args, _, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = inspect.getfullargspec(func)
    config = Config if varkw else None # Allow extra params if there is a **kwargs parameter in the function signature

    # getfullargspec reports None when there are no defaults; normalise to
    # empty containers so the tuple arithmetic and .get() below cannot fail.
    defaults = tuple(defaults or ())
    kwonlydefaults = kwonlydefaults or {}
    args = list(args or [])

    # Defaults belong to the trailing parameters. Align them against the FULL
    # argument list first; dropping excluded names before this step would
    # shift defaults onto the wrong parameters.
    non_default_args = len(args) - len(defaults)
    aligned_defaults = (...,) * non_default_args + defaults
    keyword_only_params = {param: kwonlydefaults.get(param, Any) for param in kwonlyargs}

    # Image-like annotations cannot travel over JSON; the API represents them
    # as strings instead.
    for k, v in annotations.items():
        if v == List[Image.Image]:
            annotations[k] = List[str]
        elif v == Image.Image:
            annotations[k] = str
        elif str(v) == 'typing.List[modules.control.unit.Unit]':
            annotations[k] = List[str]

    model_fields = {
        param: (annotations.get(param, Any), default)
        for param, default in zip(args, aligned_defaults)
    }

    for fld in additional_fields:
        model_def = ModelDef(
            field=underscore(fld["key"]),
            field_alias=fld["key"],
            field_type=fld["type"],
            field_value=fld["default"],
            field_exclude=fld.get("exclude", False),
        )
        model_fields[model_def.field] = (
            model_def.field_type,
            Field(default=model_def.field_value, alias=model_def.field_alias, exclude=model_def.field_exclude),
        )

    # Exclusions apply to signature-derived and additional fields alike.
    for fld in excluded:
        model_fields.pop(fld, None)

    model = create_model(
        model_name,
        **model_fields,
        **keyword_only_params,
        __base__=base_model,
        __config__=config,
    )
    try:
        # Best effort: these config attributes may not exist or be settable
        # on the installed pydantic version, in which case they are skipped.
        model.__config__.allow_population_by_field_name = True
        model.__config__.allow_mutation = True
    except Exception:
        pass
    return model
def prompt_check(
    prompt: str = Body("", title='prompt text'),
    lang: str = Body("eng", title='allowed languages'),
    alphabet: str = Body("latn", title='allowed alphabets'),
):
    """Detect the language of ``prompt`` and check it against allowed lists.

    ``lang`` and ``alphabet`` are comma separated lists; an empty list means
    "anything is allowed". Returns the detection result together with the
    two boolean verdicts.
    """
    from scripts.nudenet import langdetect # pylint: disable=no-name-in-module

    def split_csv(text):
        # Empty/None input yields no restrictions at all.
        if not text:
            return []
        return [part.strip() for part in text.split(',')]

    detected = langdetect.lang_detect(prompt)
    if isinstance(detected, list):
        detected = ','.join(detected)

    def allowed(candidates):
        # No candidates configured -> always acceptable.
        if len(candidates) == 0:
            return True
        return any(candidate in detected for candidate in candidates)

    wanted_langs = split_csv(lang)
    wanted_alphabets = split_csv(alphabet)
    return {
        "lang": detected,
        "lang_ok": allowed(wanted_langs),
        "alph_ok": allowed(wanted_alphabets),
    }
try:
    from installer import install, log
except Exception:
    # Running outside of the application (e.g. as a standalone script):
    # fall back to a no-op installer and a plain stdlib logger.
    def install(*args, **kwargs): # pylint: disable=unused-argument
        pass
    import logging
    log = logging.getLogger(__name__)

nvml_initialized = False
warned = False

# Bit flags reported by nvmlDeviceGetCurrentClocksThrottleReasons mapped to
# the human readable label shown in the "State" entry.
_THROTTLE_REASONS = {
    1: 'gpu idle',
    2: 'applications clocks setting',
    4: 'sw power cap',
    8: 'hw slowdown',
    16: 'sync boost',
    32: 'sw thermal slowdown',
    64: 'hw thermal slowdown',
    128: 'hw power brake slowdown',
    256: 'display clock setting',
}


def warn_once(msg):
    """Log ``msg`` as an error the first time only; later calls are silent."""
    global warned # pylint: disable=global-statement
    if not warned:
        log.error(msg)
        warned = True


def get_reason(val):
    """Translate a throttle-reason bitmask into a comma separated description.

    Returns ``'ok'`` when none of the known bits is set.
    """
    active = [label for bit, label in _THROTTLE_REASONS.items() if bit & val]
    return ', '.join(active) if active else 'ok'


def get_nvml():
    """Return a list with one ``{'name', 'data', 'chart'}`` dict per NVML device.

    Any failure is reported once through ``warn_once`` and results in an empty
    list; after a failure has been reported, every later call returns an empty
    list immediately without querying NVML again.
    """
    global nvml_initialized # pylint: disable=global-statement
    if warned:
        return []
    try:
        from modules.memstats import ram_stats
        if not nvml_initialized:
            install('nvidia-ml-py', quiet=True)
            import pynvml # pylint: disable=redefined-outer-name
            pynvml.nvmlInit()
            log.debug('NVML initialized')
            nvml_initialized = True
        else:
            import pynvml
        # System RAM statistics do not depend on the GPU being inspected, so
        # query them once instead of once per device.
        ram = ram_stats()
        devices = []
        for i in range(pynvml.nvmlDeviceGetCount()):
            dev = pynvml.nvmlDeviceGetHandleByIndex(i)
            try:
                name = pynvml.nvmlDeviceGetName(dev)
            except Exception:
                name = ''
            load = pynvml.nvmlDeviceGetUtilizationRates(dev)
            mem = pynvml.nvmlDeviceGetMemoryInfo(dev)
            data = {
                "CUDA": f'Version {pynvml.nvmlSystemGetCudaDriverVersion()} Compute {pynvml.nvmlDeviceGetCudaComputeCapability(dev)}',
                "Driver": pynvml.nvmlSystemGetDriverVersion(),
                "Hardware": f'VBIOS {pynvml.nvmlDeviceGetVbiosVersion(dev)} ROM {pynvml.nvmlDeviceGetInforomImageVersion(dev)}',
                "PCI link": f'Gen.{pynvml.nvmlDeviceGetCurrPcieLinkGeneration(dev)} x{pynvml.nvmlDeviceGetCurrPcieLinkWidth(dev)}',
                "Power": f'{round(pynvml.nvmlDeviceGetPowerUsage(dev)/1000, 2)} W / {round(pynvml.nvmlDeviceGetEnforcedPowerLimit(dev)/1000, 2)} W',
                "GPU clock": f'{pynvml.nvmlDeviceGetClockInfo(dev, 0)} Mhz / {pynvml.nvmlDeviceGetMaxClockInfo(dev, 0)} Mhz',
                "SM clock": f'{pynvml.nvmlDeviceGetClockInfo(dev, 1)} Mhz / {pynvml.nvmlDeviceGetMaxClockInfo(dev, 1)} Mhz',
                "VRAM clock": f'{pynvml.nvmlDeviceGetClockInfo(dev, 2)} Mhz / {pynvml.nvmlDeviceGetMaxClockInfo(dev, 2)} Mhz',
                "VRAM usage": f'{round(100 * mem.used / mem.total)}% | {round(mem.used / 1024 / 1024)} MB used | {round(mem.free / 1024 / 1024)} MB free | {round(mem.total / 1024 / 1024)} MB total',
                "RAM usage": f'{round(100 * ram["used"] / ram["total"])}% | {round(1024 * ram["used"])} MB used | {round(1024 * ram["free"])} MB free | {round(1024 * ram["total"])} MB total',
                "System load": f'GPU {load.gpu}% | VRAM {load.memory}% | Temp {pynvml.nvmlDeviceGetTemperature(dev, 0)}C | Fan {pynvml.nvmlDeviceGetFanSpeed(dev)}%',
                'State': get_reason(pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(dev)),
            }
            chart = [load.memory, load.gpu]
            devices.append({
                'name': name,
                'data': data,
                'chart': chart,
            })
        return devices
    except Exception as e:
        warn_once(f'NVML: {e}')
        return []


if __name__ == '__main__':
    nvml_initialized = True
    import pynvml # pylint: disable=redefined-outer-name
    pynvml.nvmlInit()
    from rich import print as rprint
    for gpu in get_nvml():
        rprint(gpu)
class ReqMask(BaseModel):
    # Request body for the masking endpoint.
    # `mask` and `model` are documented as optional, so they carry an explicit
    # default of None: an Optional annotation alone does not make a field
    # omittable on every pydantic version.
    image: str = Field(title="Image", description="The base64 encoded image")
    type: str = Field(title="Mask type", description="Type of masking image to return")
    mask: Optional[str] = Field(default=None, title="Mask", description="If optional mask image is not provided auto-masking will be performed")
    model: Optional[str] = Field(default=None, title="Model", description="The model to use for preprocessing")
    params: Optional[dict] = Field(default={}, title="Settings", description="Preprocessor settings")
def post_preprocess(self, req: ReqPreprocess):
    """Run the requested preprocessor on a base64 image and return the result as base64.

    The processor instance is cached in the module-level ``processor`` and is only
    re-created (under ``queue_lock``) when a different model id is requested.

    Returns a ``ResPreprocess`` on success, or a 400 ``JSONResponse`` when the model id
    is unknown or a supplied parameter is not listed in the processor configuration.
    """
    global processor # pylint: disable=global-statement
    from modules.control import processors
    if req.model not in processors.config:
        return JSONResponse(status_code=400, content={"error": f"Processor model not found: id={req.model}"})
    image = decode_base64_to_image(req.image)
    if processor is None or processor.processor_id != req.model:
        with self.queue_lock:
            processor = processors.Processor(req.model)
    params = req.params or {} # field is Optional so an explicit null must not crash the loop below
    allowed = processors.config[processor.processor_id].get('params', {}) # same tolerance for a missing key as get_preprocess
    for k, v in params.items():
        if k not in allowed:
            return JSONResponse(status_code=400, content={"error": f"Processor invalid parameter: id={req.model} {k}={v}"})
    jobid = shared.state.begin('API-PRE', api=True)
    try:
        processed = processor(image, local_config=params)
        image = encode_pil_to_base64(processed)
    finally:
        shared.state.end(jobid) # always close the job, otherwise a failing processor leaves it open
    return ResPreprocess(model=processor.processor_id, image=image)
def post_mask(self, req: ReqMask):
    """Create a mask for a base64 image and return it as base64 in a ``ResMask``.

    Every validation failure (unknown model, unknown mask type, unknown option name,
    empty result) is answered with a 400 ``JSONResponse``. Options from ``req.params``
    are written to ``masking.opts`` one by one, so entries preceding an invalid name
    are already applied when the request is rejected.
    """
    from modules import masking
    if req.model:
        if req.model not in masking.MODELS:
            return JSONResponse(status_code=400, content={"error": f"Mask model not found: id={req.model}"})
        masking.init_model(req.model)
    if req.type not in masking.TYPES:
        return JSONResponse(status_code=400, content={"error": f"Mask type not found: id={req.type}"})
    source_image = decode_base64_to_image(req.image)
    source_mask = None
    if req.mask:
        source_mask = decode_base64_to_image(req.mask)
    for k, v in req.params.items():
        if hasattr(masking.opts, k):
            setattr(masking.opts, k, v)
            continue
        return JSONResponse(status_code=400, content={"error": f"Mask invalid parameter: {k}={v}"})
    jobid = shared.state.begin('API-MASK', api=True)
    with self.queue_lock:
        result = masking.run_mask(input_image=source_image, input_mask=source_mask, return_type=req.type)
    shared.state.end(jobid)
    if result is None:
        return JSONResponse(status_code=400, content={"error": "Mask is none"})
    return ResMask(mask=encode_pil_to_base64(result))
def set_upscalers(self, req: dict):
    """Return the request fields as a new dict with the upscaler keys renamed.

    ``upscaler_1`` / ``upscaler_2`` are moved to ``extras_upscaler_1`` /
    ``extras_upscaler_2`` (``None`` when absent). ``req`` may be a plain dict or any
    object whose attributes are exposed through ``vars()``.

    The result is a shallow copy: the caller's request object is left untouched, so
    later changes to the returned dict do not leak back into it.
    """
    reqDict = dict(req) if isinstance(req, dict) else dict(vars(req))
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict
class ThrottleStatus(IntFlag):
    """Bit flags of the amdgpu throttle status value.

    Iterating an instance yields the names of the flags that are set, in declaration
    order; ``str()`` joins those names with ``', '``.
    """
    # linux/drivers/gpu/drm/amd/pm/inc/amdgpu_smu.h
    PPT0 = 1 << 0
    PPT1 = 1 << 1
    PPT2 = 1 << 2
    PPT3 = 1 << 3
    SPL = 1 << 4
    FPPT = 1 << 5
    SPPT = 1 << 6
    SPPT_APU = 1 << 7
    TDC_GFX = 1 << 16
    TDC_SOC = 1 << 17
    TDC_MEM = 1 << 18
    TDC_VDD = 1 << 19
    TDC_CVIP = 1 << 20
    EDC_CPU = 1 << 21
    EDC_GFX = 1 << 22
    APCC = 1 << 23
    TEMP_GPU = 1 << 32
    TEMP_CORE = 1 << 33
    TEMP_MEM = 1 << 34
    TEMP_EDGE = 1 << 35
    TEMP_HOTSPOT = 1 << 36
    TEMP_SOC = 1 << 37
    TEMP_VR_GFX = 1 << 38
    TEMP_VR_SOC = 1 << 39
    TEMP_VR_MEM0 = 1 << 40
    TEMP_VR_MEM1 = 1 << 41
    TEMP_LIQUID0 = 1 << 42
    TEMP_LIQUID1 = 1 << 43
    VRHOT0 = 1 << 44
    VRHOT1 = 1 << 45
    PROCHOT_CPU = 1 << 46
    PROCHOT_GFX = 1 << 47
    PPM = 1 << 56
    FIT = 1 << 57

    def active(self):
        """Return a generator over the names of all flags set in this value."""
        bits = self.value
        declared = type(self).__members__
        return (name for name, flag in declared.items() if flag._value_ & bits) # pylint: disable=protected-access

    def __iter__(self):
        return self.active()

    def __str__(self):
        return ', '.join(self.active())
(Sensor edge) (C)', 'unknown'), 'temp_junction': rocm_smi_data[key].get('Temperature (Sensor junction) (C)', 'unknown'), 'temp_memory': rocm_smi_data[key].get('Temperature (Sensor memory) (C)', 'unknown'), 'fan': rocm_smi_data[key].get('Fan speed (%)', 'unknown'), } data = { "ROCm": f'version {rocm_version} agent {rocm_smi_data[key].get("GFX Version", "unknown")}', "Driver": driver_version, "Hardware": f'VBIOS {rocm_smi_data[key].get("VBIOS version", "unknown")}', "PCI link": f'Gen.{int(math.log2(float(rocm_smi_data[key].get("pcie_link_speed (0.1 GT/s)", 10)) / 10))} x{rocm_smi_data[key].get("pcie_link_width (Lanes)", "unknown")}', "Power": f'{round(float(rocm_smi_data[key].get("Average Graphics Package Power (W)", 0)), 2)} W / {round(float(rocm_smi_data[key].get("Max Graphics Package Power (W)", 0)), 2)} W', "GPU clock": f'{rocm_smi_data[key].get("average_gfxclk_frequency (MHz)", 0)} Mhz / {rocm_smi_data[key].get("Valid sclk range", "0").split(" - ")[-1].removesuffix("Mhz")} Mhz', "VRAM clock": f'{rocm_smi_data[key].get("current_uclk (MHz)", 0)} Mhz / {rocm_smi_data[key].get("Valid mclk range", "0").split(" - ")[-1].removesuffix("Mhz")} Mhz', "VRAM usage": f'{load["memory"]}% Used | {rocm_smi_data[key].get("GPU Memory Read/Write Activity (%)", "unknown")}% Activity', "GPU usage": f'GPU {load["gpu"]}% | Fan {load["fan"]}%', "GPU temp": f'Edge {load["temp"]}C | Junction {load["temp_junction"]}C | Memory {load["temp_memory"]}C', 'Throttle reason': str(ThrottleStatus(int(rocm_smi_data[key].get("throttle_status", 0)))), } name = rocm_smi_data[key].get('Device Name', 'unknown') chart = [load["memory"], load["gpu"]] devices.append({ 'name': name, 'data': data, 'chart': chart, }) return devices except Exception as e: log.error(f'ROCm SMI: {e}') return [] if __name__ == '__main__': from rich import print as rprint for gpu in get_rocm_smi(): rprint(gpu) ================================================ FILE: modules/api/script.py 
def get_script(script_name, script_runner):
    """Return the script from ``script_runner.scripts`` matching ``script_name``.

    Returns ``None`` when no name is given (``None``, empty string or ``'none'``) or
    when the name cannot be resolved. The result is always a single value, so callers
    can rely on an ``is None`` check for every "not found" case.
    """
    if script_name is None or script_name == "" or script_name == 'none':
        return None # was a (None, None) tuple, which slipped past the callers' `is None` check
    script_idx = script_name_to_index(script_name, script_runner.scripts)
    if script_idx is None:
        return None
    return script_runner.scripts[script_idx]
def get_js(request: Request):
    """Serve a static asset named by the ``file`` query parameter.

    Only the extensions listed in the table below are served; the table is both the
    allow-list and the source of the response media type.

    Raises:
        HTTPException: 400 when ``file`` is missing/empty or has an extension that is
            not allowed, 404 when the path is not an existing regular file.
    """
    media_types = {
        'js': 'application/javascript',
        'mjs': 'application/javascript',
        'map': 'application/json',
        'json': 'application/json',
        'css': 'text/css',
        'html': 'text/html',
        'wasm': 'application/wasm',
        'ttf': 'font/ttf',
    }
    file = request.query_params.get("file", None)
    if (file is None) or (len(file) == 0):
        raise HTTPException(status_code=400, detail="file parameter is required")
    ext = file.split('.')[-1]
    media_type = media_types.get(ext)
    if media_type is None:
        raise HTTPException(status_code=400, detail=f"invalid file extension: {ext}")
    # NOTE(review): the path comes straight from the client and is only filtered by extension;
    # confirm whether it should also be restricted to known asset folders
    if not os.path.isfile(file): # a directory with a matching name must not reach FileResponse
        shared.log.error(f"API: file not found: {file}")
        raise HTTPException(status_code=404, detail=f"file not found: {file}")
    return FileResponse(file, media_type=media_type)
" # pylint: disable=use-maxsplit-arg if shared.opts.motd: try: res = requests.get('https://vladmandic.github.io/sdnext/motd', timeout=3) if res.status_code == 200: msg = (res.text or '').strip() shared.log.info(f'MOTD: {msg if len(msg) > 0 else "N/A"}') motd += res.text else: shared.log.error(f'MOTD: {res.status_code}') except Exception as err: shared.log.error(f'MOTD: {err}') return motd def get_version(): return shared.get_version() def get_platform(): from installer import get_platform as installer_get_platform from modules.loader import get_packages as loader_get_packages return { **installer_get_platform(), **loader_get_packages() } def get_log(req: models.ReqGetLog = Depends()): lines = shared.log.buffer[:req.lines] if req.lines > 0 else shared.log.buffer.copy() if req.clear: shared.log.buffer.clear() return lines def post_log(req: models.ReqPostLog): if req.message is not None: shared.log.info(f'UI: {req.message}') if req.debug is not None: shared.log.debug(f'UI: {req.debug}') if req.error is not None: shared.log.error(f'UI: {req.error}') return {} def get_config(): options = {} for k in shared.opts.data.keys(): if shared.opts.data_labels.get(k) is not None: options.update({k: shared.opts.data.get(k, shared.opts.data_labels.get(k).default)}) else: options.update({k: shared.opts.data.get(k, None)}) if 'sd_lyco' in options: del options['sd_lyco'] if 'sd_lora' in options: del options['sd_lora'] return options def set_config(req: dict[str, Any]): updated = [] for k, v in req.items(): updated.append({ k: shared.opts.set(k, v) }) shared.opts.save() return { "updated": updated } def get_cmd_flags(): return vars(shared.cmd_opts) def get_history(req: models.ReqHistory = Depends()): if req.id is not None and len(req.id) > 0: res = [item for item in shared.state.state_history if item['id'] == req.id] else: res = shared.state.state_history res = [models.ResHistory(**item) for item in res] return res def get_progress(req: models.ReqProgress = Depends()): if 
def get_memory():
    """Collect RAM usage of the current process and CUDA memory statistics.

    Both sections are gathered independently: if one of them fails, that section is
    returned as ``{'error': <message>}`` and the other one is still reported.
    Returns a ``models.ResMemory`` built from the two sections.
    """
    try:
        import psutil
        proc = psutil.Process(os.getpid())
        rss = proc.memory_info().rss # only rss is cross-platform guaranteed so no other field is used
        ram_total = 100 * rss / proc.memory_percent() # total is derived from rss as the actual value is not cross-platform safe
        ram = { 'free': ram_total - rss, 'used': rss, 'total': ram_total }
    except Exception as err:
        ram = { 'error': f'{err}' }
    try:
        import torch
        if not torch.cuda.is_available():
            cuda = { 'error': 'unavailable' }
        else:
            info = torch.cuda.mem_get_info()
            free_bytes, total_bytes = info[0], info[1]
            system = { 'free': free_bytes, 'used': total_bytes - free_bytes, 'total': total_bytes }
            stats = dict(torch.cuda.memory_stats(shared.device))

            def usage(prefix):
                # current/peak pair for one allocator counter family
                return { 'current': stats[f'{prefix}.all.current'], 'peak': stats[f'{prefix}.all.peak'] }

            allocated = usage('allocated_bytes')
            reserved = usage('reserved_bytes')
            active = usage('active_bytes')
            inactive = usage('inactive_split_bytes')
            events = { 'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms'] }
            cuda = {
                'system': system,
                'active': active,
                'allocated': allocated,
                'reserved': reserved,
                'inactive': inactive,
                'events': events,
            }
    except Exception as err:
        cuda = { 'error': f'{err}' }
    return models.ResMemory(ram = ram, cuda = cuda)
def xyz_grid_enum(option: str = "") -> List[dict]:
    """List the axis options of the XYZ grid script.

    Without ``option`` every axis is returned and ``choices`` is only a flag telling
    whether the axis defines choices. With ``option`` only axes whose label starts or
    ends with it (case-insensitive) are returned, and a callable ``choices`` is invoked
    so the entry carries the actual values.
    """
    from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
    list_all = len(option) == 0
    wanted = option.lower()
    found = []
    for axis in xyz_grid_classes.axis_options:
        entry = {
            'label': axis.label,
            'type': axis.type.__name__,
            'cost': axis.cost,
            'choices': axis.choices is not None,
        }
        if not list_all:
            label = axis.label.lower()
            if not (label.startswith(wanted) or label.endswith(wanted)):
                continue
            if callable(axis.choices):
                entry['choices'] = axis.choices()
        found.append(entry)
    return found
def set_flex_attention():
    """Replace ``torch.nn.functional.scaled_dot_product_attention`` with a flex attention wrapper.

    The wrapper keeps the SDPA call signature:
    - a boolean ``attn_mask`` is turned into a block mask,
    - any other ``attn_mask`` is added to the attention scores through ``score_mod``,
    - ``is_causal`` without a mask uses a causal block mask.
    ``dropout_p`` and extra keyword arguments are accepted but not forwarded to ``flex_attention``.

    Errors raised while importing or installing the wrapper are logged, not propagated.
    """
    try:
        from torch.nn.attention.flex_attention import flex_attention, create_block_mask

        def flex_attention_causal_mask(b, h, q_idx, kv_idx): # pylint: disable=unused-argument
            return q_idx >= kv_idx

        sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention

        @wraps(sdpa_pre_flex_atten)
        def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
            score_mod = None
            block_mask = None
            if attn_mask is not None:
                batch_size, num_heads = query.shape[:2]
                seq_len_q = query.shape[-2]
                seq_len_kv = key.shape[-2]
                if attn_mask.ndim == 2:
                    # `size` is a method, so the previous `attn_mask.size[1]` raised TypeError for every 2D mask
                    # NOTE(review): a 2D mask is reshaped to (dim0, 1, dim1, 1) before broadcasting - confirm this matches the layout callers pass
                    attn_mask = attn_mask.view(attn_mask.shape[0], 1, attn_mask.shape[1], 1)
                attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
                if attn_mask.dtype == torch.bool:
                    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                        return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
                    block_mask = create_block_mask(mask_mod, batch_size, None, seq_len_q, seq_len_kv, device=query.device)
                else:
                    def score_mod_fn(score, batch_idx, head_idx, q_idx, kv_idx):
                        return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
                    score_mod = score_mod_fn
            elif is_causal:
                block_mask = create_block_mask(flex_attention_causal_mask, query.shape[0], query.shape[1], query.shape[-2], key.shape[-2], device=query.device)
            return flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa)

        torch.nn.functional.scaled_dot_product_attention = sdpa_flex_atten
        log.debug('Torch attention: type="Flex attention"')
    except Exception as err:
        log.error(f'Torch attention: type="Flex attention" {err}')
def set_sage_attention(backend: str, device: torch.device):
    """
    Install the `sageattention` package and patch torch SDPA to use it where supported.

    Kernel selection happens once, up front: when `backend` is "cuda" and the device reports
    compute capability (8, 6), the dedicated CUDA kernel is used if it can be imported
    (sm86 confirmed to need CUDA backend workaround as Sage Attention + Triton causes NaNs);
    in every other case the generic `sageattn` entry point is used.

    The patched SDPA only takes the sage path for head sizes 64/96/128, no attention mask and
    non-float32 inputs; all other calls go to the previously installed SDPA implementation.
    Errors during setup are logged and not re-raised.
    """
    try:
        install('sageattention')

        cuda_kernel = None
        if backend == "cuda" and torch.cuda.get_device_capability(device) == (8, 6):
            try:
                from sageattention import sageattn_qk_int8_pv_fp16_cuda as cuda_kernel
            except Exception:
                cuda_kernel = None # kernel not shipped with this build, fall back to the generic entry point
        use_cuda_backend = cuda_kernel is not None

        if use_cuda_backend:
            def sage_attn_impl(query, key, value, is_causal, scale):
                return cuda_kernel(
                    q=query, k=key, v=value,
                    tensor_layout="HND",
                    is_causal=is_causal,
                    sm_scale=scale,
                    return_lse=False,
                    pv_accum_dtype="fp32",
                )
        else:
            from sageattention import sageattn

            def sage_attn_impl(query, key, value, is_causal, scale):
                return sageattn(
                    q=query, k=key, v=value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=is_causal,
                    scale=scale,
                )

        sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
        supported_head_sizes = {128, 96, 64}

        @wraps(sdpa_pre_sage_atten)
        def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
            sage_eligible = (query.shape[-1] in supported_head_sizes) and (attn_mask is None) and (query.dtype != torch.float32)
            if not sage_eligible:
                # unsupported call: hand over to the implementation that was active before patching
                if enable_gqa:
                    kwargs["enable_gqa"] = enable_gqa
                return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
            if enable_gqa:
                # expand key/value heads so their count matches the query heads
                key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
                value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
            return sage_attn_impl(query, key, value, is_causal, scale)

        torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
        log.debug(f'Torch attention: type="Sage attention" backend={"cuda" if use_cuda_backend else "auto"}')
    except Exception as err:
        log.error(f'Torch attention: type="Sage attention" {err}')
def remove(image, refine: bool = True):
    """
    Run the BEN2 model on `image` and return the extracted foreground.

    On first use the `PramaLLC/BEN2` checkpoint is fetched through `hf_hub_download` and the
    network is cached in the module-level `model`. For every call the model is moved to
    `devices.device` for inference and back to `devices.cpu` afterwards.

    Args:
        image: input passed unchanged to `model.inference`.
        refine: forwarded to `model.inference` as `refine_foreground`.

    Returns:
        The value returned by `model.inference`, or the original `image` when that value is None.
    """
    global model # pylint: disable=global-statement
    from modules import shared, devices
    if model is None:
        from huggingface_hub import hf_hub_download
        from .ben2_model import BEN_Base
        # build into a local first: if download or checkpoint load fails the global must stay None,
        # otherwise later calls would skip loading and run a network without weights
        network = BEN_Base()
        model_file = hf_hub_download(
            repo_id='PramaLLC/BEN2',
            filename='BEN2_Base.pth',
            cache_dir=shared.opts.hfcache_dir)
        network.loadcheckpoints(model_file)
        model = network.to(device=devices.device, dtype=devices.dtype).eval()
    model = model.to(device=devices.device)
    try:
        foreground = model.inference(image, refine_foreground=refine)
    finally:
        # always offload, even when inference raises, so the model does not stay on the compute device
        model = model.to(device=devices.cpu)
    if foreground is None:
        return image
    return foreground
class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # a falsy qk_scale (None or 0) selects the default

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # explicit indexing="ij" is the layout the implicit default produced; passing it avoids the deprecation warning
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row offset becomes the stride of a flat table index
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor with the same (num_windows*B, N, C) shape as `x`.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # add the per-window mask, broadcasting it over batch and heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.num_heads = num_heads self.window_size = window_size self.shift_size = shift_size self.mlp_ratio = mlp_ratio assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.H = None self.W = None def forward(self, x, mask_matrix): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. mask_matrix: Attention mask for cyclic shift. 
class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Concatenates every 2x2 neighbourhood of tokens along the channel axis (C -> 4*C),
    normalizes the result and projects it down to 2*C channels.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            Tensor of size (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        grid = x.view(B, H, W, C)

        # an odd height or width gets one trailing row/column of zeros so the grid splits into 2x2 cells
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # cell members in (row, col) offset order: (0,0), (1,0), (0,1), (1,1) - each is B H/2 W/2 C
        cell_members = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(cell_members, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
""" def __init__(self, dim, depth, num_heads, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() self.window_size = window_size self.shift_size = window_size // 2 self.depth = depth self.use_checkpoint = use_checkpoint # build blocks self.blocks = nn.ModuleList([ SwinTransformerBlock( dim=dim, num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer if downsample is not None: self.downsample = downsample(dim=dim, norm_layer=norm_layer) else: self.downsample = None def forward(self, x, H, W): """ Forward function. Args: x: Input feature, tensor size (B, H*W, C). H, W: Spatial resolution of the input feature. 
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Projects non-overlapping patches of the input with a strided convolution and optionally
    normalizes the resulting embeddings.

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Forward function."""
        patch_h, patch_w = self.patch_size
        _, _, height, width = x.size()

        # pad right/bottom so both spatial sizes become multiples of the patch size
        extra_w = -width % patch_w
        extra_h = -height % patch_h
        if extra_w:
            x = F.pad(x, (0, extra_w))
        if extra_h:
            x = F.pad(x, (0, 0, 0, extra_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # normalization works on the channel axis: flatten to tokens, normalize, restore the grid
        grid_h, grid_w = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, grid_h, grid_w)
Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. patch_norm (bool): If True, add normalization after patch embedding. Default: True. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. """ def __init__(self, pretrain_img_size=224, patch_size=4, in_chans=3, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, use_checkpoint=False): super().__init__() self.pretrain_img_size = pretrain_img_size self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm self.out_indices = out_indices self.frozen_stages = frozen_stages # split image into non-overlapping patches self.patch_embed = PatchEmbed( patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) # absolute position embedding if self.ape: pretrain_img_size = to_2tuple(pretrain_img_size) patch_size = to_2tuple(patch_size) patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): layer = BasicLayer( dim=int(embed_dim * 2 ** i_layer), depth=depths[i_layer], num_heads=num_heads[i_layer], window_size=window_size, 
def get_activation_fn(activation):
    """Map an activation name to its functional implementation; only "gelu" is supported."""
    if activation != "gelu":
        raise RuntimeError(f"activation should be gelu, not {activation}.")
    return F.gelu
class PositionEmbeddingSine:
    """
    Sinusoidal 2D position embedding computed for a dense (b, h, w) grid.

    Args:
        num_pos_feats (int): Number of features per spatial axis; the output has 2 * num_pos_feats channels.
        temperature (int): Base of the frequency progression.
        normalize (bool): If True, positions are rescaled to the range given by `scale` before encoding.
        scale (float, optional): Range used when `normalize` is True. Default: 2 * pi.

    Raises:
        ValueError: if `scale` is given while `normalize` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale
        self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)

    def __call__(self, b, h, w):
        """Return the embedding as a (b, 2 * num_pos_feats, h, w) tensor on the device of `self.dim_t`."""
        device = self.dim_t.device
        # every position is valid, so the running sums are simply 1-based row/column indices
        valid = torch.ones([b, h, w], dtype=torch.bool, device=device)
        rows = valid.cumsum(dim=1, dtype=torch.float32)
        cols = valid.cumsum(dim=2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            rows = (rows - 0.5) / (rows[:, -1:, :] + eps) * self.scale
            cols = (cols - 0.5) / (cols[:, :, -1:] + eps) * self.scale

        # consecutive feature pairs share one frequency
        freqs = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)

        def encode(positions):
            # sin on even feature indices, cos on odd ones, interleaved back into a single axis
            phase = positions[:, :, :, None] / freqs
            return torch.stack((phase[:, :, :, 0::2].sin(), phase[:, :, :, 1::2].cos()), dim=4).flatten(3)

        pos_x = encode(cols)
        pos_y = encode(rows)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
class MCRM(nn.Module):
    """
    Refinement block operating on a stack of four local patches plus one global view.

    The global features are pooled at several ratios and used as key/value for one attention
    module per local patch; the refreshed patches are then merged back into the global view.

    Args:
        d_model (int): Channel count of the incoming features.
        num_heads (int): Number of heads of every attention module.
        pool_ratios (list[int]): Down-scaling ratios used to pool the global features.
        h: Unused.
    """

    def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): # pylint: disable=unused-argument
        super().__init__()
        # one attention module per local patch
        self.attention = nn.ModuleList(nn.MultiheadAttention(d_model, num_heads, dropout=0.1) for _ in range(4))

        self.linear3 = nn.Linear(d_model, d_model * 2)
        self.linear4 = nn.Linear(d_model * 2, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.activation = get_activation_fn('gelu')
        self.sal_conv = nn.Conv2d(d_model, 1, 1)
        self.pool_ratios = pool_ratios

    def forward(self, x):
        """
        Args:
            x: stacked features, first four entries are the local patches and the last one the global view.

        Returns:
            Tuple of (refreshed stack with the same layout as `x`, token attention map).
        """
        _b, c, h, w = x.size()
        loc, glb = x.split([4, 1], dim=0)  # 4,c,h,w; 1,c,h,w

        split_pattern = 'b c (hg h) (wg w) -> (hg wg b) c h w'
        patched_glb = rearrange(glb, split_pattern, hg=2, wg=2)

        # saliency predicted on the global view gates the local patches
        token_attention_map = self.sigmoid(self.sal_conv(glb))
        token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
        loc = loc * rearrange(token_attention_map, split_pattern, hg=2, wg=2)

        # multi-scale pooled global tokens, one set per local patch
        pooled = []
        for ratio in self.pool_ratios:
            target_hw = (round(h / ratio), round(w / ratio))
            level = F.adaptive_avg_pool2d(patched_glb, target_hw)
            pooled.append(rearrange(level, 'nl c h w -> nl c (h w)'))  # nl(4),c,hw
        pools = rearrange(torch.cat(pooled, 2), "nl c nphw -> nl nphw 1 c")

        # every local patch (query) attends to its own pooled global tokens (key == value)
        queries = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
        attended = [self.attention[idx](query, pools[idx], pools[idx])[0] for idx, query in enumerate(queries.unbind(dim=0))]
        attended = torch.cat(attended, 1)

        src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(attended)
        src = self.norm1(src)
        # clone before dropout so an in-place dropout cannot modify the activation output
        hidden = self.activation(self.linear3(src)).clone()
        src = src + self.dropout2(self.linear4(self.dropout(hidden)))
        src = self.norm2(src)
        src = src.permute(1, 2, 0).reshape(4, c, h, w)  # refreshed loc

        glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest')  # refreshed glb
        return torch.cat((src, glb), 0), token_attention_map
self.output5 = make_cbr(1024, emb_dim) self.output4 = make_cbr(512, emb_dim) self.output3 = make_cbr(256, emb_dim) self.output2 = make_cbr(128, emb_dim) self.output1 = make_cbr(128, emb_dim) self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8]) self.conv1 = make_cbr(emb_dim, emb_dim) self.conv2 = make_cbr(emb_dim, emb_dim) self.conv3 = make_cbr(emb_dim, emb_dim) self.conv4 = make_cbr(emb_dim, emb_dim) self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8]) self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8]) self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8]) self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8]) self.insmask_head = nn.Sequential( nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1), nn.InstanceNorm2d(384), nn.GELU(), nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.InstanceNorm2d(384), nn.GELU(), nn.Conv2d(384, emb_dim, kernel_size=3, padding=1) ) self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1)) self.upsample1 = make_cbg(emb_dim, emb_dim) self.upsample2 = make_cbg(emb_dim, emb_dim) self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) for m in self.modules(): if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout): m.inplace = True @torch.inference_mode() @torch.autocast(device_type="cuda",dtype=torch.float16) def forward(self, x): real_batch = x.size(0) shallow_batch = self.shallow(x) glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear') final_input = None for i in range(real_batch): start = i * 4 end = (i + 1) * 4 loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0)) input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0) if final_input is None: final_input= input_ else: final_input = torch.cat((final_input, input_), dim=0) features = self.backbone(final_input) outputs = [] for i in range(real_batch): start = i * 5 end = (i + 1) * 5 f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W] f3 = features[3][start:end, :, :, :] f2 = features[2][start:end, :, :, :] f1 = features[1][start:end, :, :, :] 
f0 = features[0][start:end, :, :, :] e5 = self.output5(f4) e4 = self.output4(f3) e3 = self.output3(f2) e2 = self.output2(f1) e1 = self.output1(f0) loc_e5, glb_e5 = e5.split([4, 1], dim=0) e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16) e4, _tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4)) e4 = self.conv4(e4) e3, _tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3)) e3 = self.conv3(e3) e2, _tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2)) e2 = self.conv2(e2) e1, _tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1)) e1 = self.conv1(e1) loc_e1, glb_e1 = e1.split([4, 1], dim=0) output1_cat = patches2image(loc_e1) # (1,128,256,256) # add glb feat in output1_cat = output1_cat + resize_as(glb_e1, output1_cat) # merge final_output = self.insmask_head(output1_cat) # (1,128,256,256) # shallow feature merge shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0) final_output = final_output + resize_as(shallow, final_output) final_output = self.upsample1(rescale_to(final_output)) final_output = rescale_to(final_output + resize_as(shallow, final_output)) final_output = self.upsample2(final_output) final_output = self.output(final_output) mask = final_output.sigmoid() outputs.append(mask) return torch.cat(outputs, dim=0) def loadcheckpoints(self,model_path): model_dict = torch.load(model_path, map_location="cpu", weights_only=True) self.load_state_dict(model_dict['model_state_dict'], strict=True) del model_path def inference(self,image,refine_foreground=False): # image = ImageOps.exif_transpose(image) if isinstance(image, Image.Image): image, h, w,original_image = rgb_loader_refiner(image) if torch.cuda.is_available(): img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device) else: img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device) with torch.no_grad(): res = self.forward(img_tensor) # Show Results if refine_foreground: pred_pil = transforms.ToPILImage()(res.squeeze()) image_masked = 
refine_foreground_process(original_image, pred_pil) image_masked.putalpha(pred_pil.resize(original_image.size)) return image_masked else: alpha = postprocess_image(res, im_size=[w,h]) pred_pil = transforms.ToPILImage()(alpha) mask = pred_pil.resize(original_image.size) original_image.putalpha(mask) # mask = Image.fromarray(alpha) return original_image else: foregrounds = [] for batch in image: image, h, w,original_image = rgb_loader_refiner(batch) if torch.cuda.is_available(): img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device) else: img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device) with torch.no_grad(): res = self.forward(img_tensor) if refine_foreground: pred_pil = transforms.ToPILImage()(res.squeeze()) image_masked = refine_foreground_process(original_image, pred_pil) image_masked.putalpha(pred_pil.resize(original_image.size)) foregrounds.append(image_masked) else: alpha = postprocess_image(res, im_size=[w,h]) pred_pil = transforms.ToPILImage()(alpha) mask = pred_pil.resize(original_image.size) original_image.putalpha(mask) # mask = Image.fromarray(alpha) foregrounds.append(original_image) return foregrounds def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)): """ Segments the given video to extract the foreground (with alpha) from each frame and saves the result as either a WebM video (with alpha channel) or MP4 (with a color background). Args: video_path (str): Path to the input video file. output_path (str, optional): Directory (or full path) where the output video and/or files will be saved. Defaults to "./". fps (int, optional): The frames per second (FPS) to use for the output video. If 0 (default), the original FPS of the input video is used. Otherwise, overrides it. refine_foreground (bool, optional): Whether to run an additional “refine foreground” process on each frame. 
Defaults to False. batch (int, optional): Number of frames to process at once (inference batch size). Large batch sizes may require more GPU memory. Defaults to 1. print_frames_processed (bool, optional): If True (default), prints progress (how many frames have been processed) to the console. webm (bool, optional): If True (default), exports a WebM video with alpha channel (VP9 / yuva420p). If False, exports an MP4 video composited over a solid color background. rgb_value (tuple, optional): The RGB background color (e.g., green screen) used to composite frames when saving to MP4. Defaults to (0, 255, 0). Returns: None. Writes the output video(s) to disk in the specified format. """ cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise IOError(f"Cannot open video: {video_path}") original_fps = cap.get(cv2.CAP_PROPFPS) original_fps = 30 if original_fps == 0 else original_fps fps = original_fps if fps == 0 else fps ret, first_frame = cap.read() if not ret: raise ValueError("No frames found in the video.") _height, _width = first_frame.shape[:2] cap.set(cv2.CAP_PROP_POSFRAMES, 0) foregrounds = [] frame_idx = 0 processed_count = 0 batch_frames = [] total_frames = int(cap.get(cv2.CAP_PROPFRAME_COUNT)) while True: ret, frame = cap.read() if not ret: if batch_frames: batch_results = self.inference(batch_frames, refine_foreground) if isinstance(batch_results, Image.Image): foregrounds.append(batch_results) else: foregrounds.extend(batch_results) if print_frames_processed: print(f"Processed frames {frame_idx-len(batch_frames)+1} to {frame_idx} of {total_frames}") break # Process every frame instead of using intervals frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) pil_frame = Image.fromarray(frame_rgb) batch_frames.append(pil_frame) if len(batch_frames) == batch: batch_results = self.inference(batch_frames, refine_foreground) if isinstance(batch_results, Image.Image): foregrounds.append(batch_results) else: foregrounds.extend(batch_results) if 
print_frames_processed: print(f"Processed frames {frame_idx-batch+1} to {frame_idx} of {total_frames}") batch_frames = [] processed_count += batch frame_idx += 1 if webm: alpha_webm_path = os.path.join(output_path, "foreground.webm") pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=original_fps) else: cap.release() fg_output = os.path.join(output_path, 'foreground.mp4') pil_images_to_mp4(foregrounds, fg_output, fps=original_fps,rgb_value=rgb_value) cv2.destroyAllWindows() try: fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4') add_audio_to_video(fg_output, video_path, fg_audio_output) except Exception as e: print("No audio found in the original video") print(e) def rgb_loader_refiner( original_image): h, w = original_image.size image = original_image # Convert to RGB if necessary if image.mode != 'RGB': image = image.convert('RGB') # Resize the image image = image.resize((1024, 1024), resample=Image.Resampling.LANCZOS) return image.convert('RGB'), h, w,original_image # Define the image transformation img_transform = transforms.Compose([ transforms.ToTensor(), transforms.ConvertImageDtype(torch.float16), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img_transform32 = transforms.Compose([ transforms.ToTensor(), transforms.ConvertImageDtype(torch.float32), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)): """ Converts an array of PIL images to an MP4 video. 
def pil_images_to_webm_alpha(images, output_path, fps=30):
    """
    Convert a list of PIL images to a VP9 .webm video with an alpha channel.

    Frames are written as numbered PNG files into a temporary directory and
    encoded by an external ``ffmpeg`` process.

    NOTE: Not all players will display alpha in WebM.
    Browsers like Chrome/Firefox typically do support VP9 alpha.

    Args:
        images: List of PIL images; non-RGBA frames are converted to RGBA.
        output_path: Path of the .webm file to write.
        fps: Frames per second (default: 30)

    Raises:
        ValueError: If ``images`` is empty.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    if not images:
        raise ValueError("No images provided for WebM with alpha.")

    # Ensure the output directory exists. A bare filename has an empty dirname,
    # and os.makedirs("") raises FileNotFoundError, so only create when present.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save frames as PNG (with alpha)
        for idx, img in enumerate(images):
            if img.mode != "RGBA":
                img = img.convert("RGBA")
            img.save(os.path.join(tmpdir, f"{idx:06d}.png"), "PNG")

        # -c:v libvpx-vp9 => VP9 encoder
        # -pix_fmt yuva420p => alpha-enabled pixel format
        # -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk)
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", os.path.join(tmpdir, "%06d.png"),
            "-c:v", "libvpx-vp9",
            "-pix_fmt", "yuva420p",
            "-auto-alt-ref", "0",
            output_path
        ]
        subprocess.run(ffmpeg_cmd, check=True)

    print(f"WebM with alpha saved to {output_path}")
""" # 1) Probe original video for audio streams probe_command = [ 'ffprobe', '-v', 'error', '-select_streams', 'a:0', '-show_entries', 'stream=index', '-of', 'csv=p=0', original_video_path ] result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False) # result.stdout is empty if no audio stream found if not result.stdout.strip(): print("No audio track found in original video, skipping audio addition.") return print("Audio track detected; proceeding to mux audio.") # 2) If audio found, run ffmpeg to add it command = [ 'ffmpeg', '-y', '-i', video_without_audio_path, '-i', original_video_path, '-c', 'copy', '-map', '0:v:0', '-map', '1:a:0', # we know there's an audio track now output_path ] subprocess.run(command, check=True) print(f"Audio added successfully => {output_path}") ### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py def refine_foreground_process(image, mask, r=90): if mask.size != image.size: mask = mask.resize(image.size) image = np.array(image) / 255.0 mask = np.array(mask) / 255.0 estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r) image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) return image_masked def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation alpha = alpha[:, :, None] F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r) return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): if isinstance(image, Image.Image): image = np.array(image) / 255.0 blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] blurredFA = cv2.blur(F * alpha, (r, r)) blurredF = blurredFA / (blurred_alpha + 1e-5) blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) F = blurredF + alpha * \ 
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a predicted mask and min-max scale it to a uint8 array.

    Args:
        result: 4-D tensor as accepted by ``F.interpolate``; the leading batch
            dimension is squeezed away after resizing.
        im_size: Target spatial size handed to ``F.interpolate`` as ``size``.

    Returns:
        ``np.uint8`` array with values in [0, 255]; singleton dimensions are
        squeezed out, so a one-channel mask comes back two-dimensional.
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    span = ma - mi
    if span > 0:
        result = (result - mi) / span
    else:
        # Constant (or NaN) prediction: min-max scaling would be 0/0 and put
        # NaN through the uint8 cast, so return an all-zero mask instead.
        result = torch.zeros_like(result)
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array
def unapply_cache_dir(pipe):
    """Disable Cache-DiT on ``pipe`` if it was previously enabled.

    Does nothing when the ``cache_dit_enabled`` option is off or when ``pipe``
    has no truthy ``has_cache_dit`` attribute. Failures are logged rather than
    raised, so callers can treat this as best-effort.
    """
    if not shared.opts.cache_dit_enabled or not getattr(pipe, 'has_cache_dit', False):
        return
    try:
        import cache_dit
        cache_dit.disable_cache(pipe)
        pipe.has_cache_dit = False
    except Exception as e:
        # log the failure instead of discarding it; still never raise
        shared.log.error(f'Cache-DiT: {e}')
shared.log.error(f"Arguments: args={str(args)[:10240]} kwargs={str(kwargs)[:10240]}") errors.display(e, 'gradio call') res = extra_outputs or [] res.append(f"
{html.escape(str(e))}
") finally: progress.finish_task(id_task) return res return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True, name=name) def wrap_gradio_call(func, extra_outputs=None, add_stats=False, name=None): job_name = name if name is not None else func.__name__ def f(*args, extra_outputs_array=extra_outputs, **kwargs): t = time.perf_counter() shared.mem_mon.reset() if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": task_id = args[0] else: task_id = 0 jobid = shared.state.begin(job_name, task_id=task_id) try: if shared.cmd_opts.profile: pr = cProfile.Profile() pr.enable() res = func(*args, **kwargs) if res is None: msg = "No result returned from function" shared.log.warning(msg) res = extra_outputs_array or [] res.append(f"
{html.escape(msg)}
") else: res = list(res) if shared.cmd_opts.profile: pr.disable() errors.profile(pr, 'Wrap') except Exception as e: errors.display(e, 'gradio call') res = extra_outputs_array or [] res.append(f"
{html.escape(type(e).__name__+': '+str(e))}
") shared.state.end(jobid) if not add_stats: return tuple(res) elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_m}m {elapsed_s:.2f}s" if elapsed_m > 0 else f"{elapsed_s:.2f}s" summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ') memory = shared.mem_mon.summary() if isinstance(res, list) and isinstance(res[-1], str): res[-1] += f"

Time: {elapsed_text} | {summary} {memory}

" return tuple(res) return f ================================================ FILE: modules/cfgzero/__init__.py ================================================ # reference: from modules import shared, processing, sd_models orig_pipeline = None supported = [ 'FluxPipeline', 'CogView4Pipeline', 'StableDiffusion3Pipeline', 'HiDreamImagePipeline', 'WanPipeline', 'HunyuanVideoPipeline', ] def apply(p: processing.StableDiffusionProcessing): if not shared.opts.cfgzero_enabled: return None cls = shared.sd_model.__class__.__name__ if shared.sd_loaded else 'None' if 'CFGZero' in cls: unapply() if cls not in supported: return None global orig_pipeline # pylint: disable=global-statement orig_pipeline = shared.sd_model if cls == 'FluxPipeline': from diffusers import pipelines from modules.cfgzero.flux_pipeline import FluxCFGZeroPipeline shared.sd_model = sd_models.switch_pipe(FluxCFGZeroPipeline, shared.sd_model) pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["fluxcfgzero"] = FluxCFGZeroPipeline pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["fluxcfgzero"] = pipelines.FluxImg2ImgPipeline pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["fluxcfgzero"] = pipelines.FluxInpaintPipeline if cls == 'CogView4Pipeline': from modules.cfgzero.cogview4_pipeline import CogView4CFGZeroPipeline shared.sd_model = sd_models.switch_pipe(CogView4CFGZeroPipeline, shared.sd_model) if cls == 'StableDiffusion3Pipeline': from modules.cfgzero.sd3_pipeline import StableDiffusion3CFGZeroPipeline shared.sd_model = sd_models.switch_pipe(StableDiffusion3CFGZeroPipeline, shared.sd_model) if cls == 'HiDreamImagePipeline': from modules.cfgzero.hidream_pipeline import HiDreamImageCFGZeroPipeline shared.sd_model = sd_models.switch_pipe(HiDreamImageCFGZeroPipeline, shared.sd_model) if cls == 'WanPipeline': from modules.cfgzero.wan_t2v_pipeline import WanCFGZeroPipeline shared.sd_model = sd_models.switch_pipe(WanCFGZeroPipeline, shared.sd_model) if cls == 'HunyuanVideoPipeline': from 
def unapply():
    """Restore the pipeline saved in ``orig_pipeline``, if any.

    When a pipeline was saved, it is put back into ``shared.sd_model`` and the
    saved reference is cleared; otherwise nothing is changed.

    Returns:
        The class of ``shared.sd_model`` after the (possible) restore.
    """
    global orig_pipeline # pylint: disable=global-statement
    saved = orig_pipeline
    if saved is None:
        # nothing was swapped in, so the active model is left untouched
        return shared.sd_model.__class__
    shared.sd_model = saved
    orig_pipeline = None
    return shared.sd_model.__class__
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
    Handles custom timesteps and/or sigmas. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps. Only forwarded to the scheduler when neither `timesteps` nor `sigmas`
            is given; otherwise it is replaced by the length of the resulting schedule.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. May be combined with
            `timesteps` when the scheduler's `set_timesteps` accepts both.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If a custom `timesteps` / `sigmas` schedule is requested but the scheduler's `set_timesteps`
            does not declare the corresponding parameter.
    """
    # inspect the signature once instead of once per flag
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    accepts_timesteps = "timesteps" in accepted_params
    accepts_sigmas = "sigmas" in accepted_params

    if timesteps is not None and sigmas is not None:
        # Both keywords are forwarded below, so BOTH must be supported. The
        # previous `and` let a scheduler supporting only one of them through,
        # which then failed inside `set_timesteps` with a TypeError.
        if not accepts_timesteps or not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is not None and sigmas is None:
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is None and sigmas is not None:
        if not accepts_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
transformer ([`CogView4Transformer2DModel`]): A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: GlmModel, vae: AutoencoderKL, transformer: CogView4Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _get_glm_embeds( self, prompt: Union[str, List[str]] = None, max_sequence_length: int = 1024, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt text_inputs = self.tokenizer( prompt, padding="longest", # not use max length max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) current_length = text_input_ids.shape[1] pad_length = (16 - 
(current_length % 16)) % 16 if pad_length > 0: pad_ids = torch.full( (text_input_ids.shape[0], pad_length), fill_value=self.tokenizer.pad_token_id, dtype=text_input_ids.dtype, device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) prompt_embeds = self.text_encoder( text_input_ids.to(self.text_encoder.device), output_hidden_states=True ).hidden_states[-2] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 1024, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): Whether to use classifier free guidance or not. num_images_per_prompt (`int`, *optional*, defaults to 1): Number of images that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. device: (`torch.device`, *optional*): torch device dtype: (`torch.dtype`, *optional*): torch dtype max_sequence_length (`int`, defaults to `1024`): Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. """ device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) seq_len = prompt_embeds.size(1) prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a generation call.

    Checks, in order: image size divisibility by 16, callback tensor names,
    mutual exclusivity of text prompts and pre-computed embeddings, the type
    of ``prompt``, and shape agreement between the two embedding tensors.

    Raises:
        ValueError: On the first check that fails.
    """
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        allowed = self._callback_tensor_inputs
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in allowed]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if not (has_embeds and has_negative_embeds):
        return

    # Both embedding tensors were passed directly: first and last dimensions must agree.
    if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
    if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 5.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    output_type: str = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 1024,
    use_cfg_zero_star: Optional[bool] = False,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
) -> Union[CogView4PipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. Must be divisible by 16.
        width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. Must be divisible by 16.
        num_inference_steps (`int`, *optional*, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, a linear schedule from
            `scheduler.config.num_train_timesteps` down to 1 with `num_inference_steps` entries is used.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, `timesteps / scheduler.config.num_train_timesteps` is
            used.
        guidance_scale (`float`, *optional*, defaults to `5.0`):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to `1`):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        original_size (`Tuple[int]`, *optional*):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the
            position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by
            setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section
            2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the VAE decode is skipped and the latents are passed to the image post-processor instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `CogView4PipelineOutput` instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. It is called as
            `callback_on_step_end(self, step, sigma, callback_kwargs)`, where `sigma` is
            `self.scheduler.sigmas[step]`. `callback_kwargs` will include all tensors named in
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
            the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to `1024`):
            Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        use_cfg_zero_star (`bool`, *optional*, defaults to `False`):
            When `True` and guidance is enabled, the unconditional prediction is multiplied by the per-sample
            scale returned by `optimized_scale` before the guidance formula is applied.
        use_zero_init (`bool`, *optional*, defaults to `True`):
            Only used together with `use_cfg_zero_star`. When `True`, every step whose index is
            `<= zero_steps` uses an all-zero prediction.
        zero_steps (`int`, *optional*, defaults to `0`):
            Index of the last step that is zeroed when `use_zero_init` is active.

    Examples:

    Returns:
        [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
        [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = (height, width)

    # Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # Default call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        negative_prompt,
        self.do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    # Prepare latents
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        torch.float32,
        device,
        generator,
        latents,
    )

    # Prepare additional timestep conditions
    original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
    target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
    crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)

    original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
    target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
    crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)

    # Prepare timesteps
    image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
        self.transformer.config.patch_size**2
    )
    timesteps = (
        np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps)
        if timesteps is None
        else np.array(timesteps)
    )
    timesteps = timesteps.astype(np.int64).astype(np.float32)
    sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("base_shift", 0.25),
        self.scheduler.config.get("max_shift", 0.75),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
    )
    self._num_timesteps = len(timesteps)

    # Denoising loop
    transformer_dtype = self.transformer.dtype
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0])

            noise_pred_cond = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                original_size=original_size,
                target_size=target_size,
                crop_coords=crops_coords_top_left,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=negative_prompt_embeds,
                    timestep=timestep,
                    original_size=original_size,
                    target_size=target_size,
                    crop_coords=crops_coords_top_left,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]

                if not use_cfg_zero_star:
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                elif use_zero_init and i <= zero_steps:
                    # Zero-init: the first `zero_steps + 1` steps use an all-zero prediction.
                    noise_pred = noise_pred_cond * 0.0
                else:
                    # The predictions carry `batch_size * num_images_per_prompt` samples, so the
                    # per-sample scale must be computed over the tensor's own leading dimension,
                    # not over the number of prompts.
                    sample_count = noise_pred_cond.shape[0]
                    positive_flat = noise_pred_cond.reshape(sample_count, -1)
                    negative_flat = noise_pred_uncond.reshape(sample_count, -1)

                    alpha = optimized_scale(positive_flat, negative_flat)
                    # Reshape to (sample_count, 1, ..., 1) so it broadcasts over the prediction.
                    alpha = alpha.view(sample_count, *([1] * (noise_pred_cond.dim() - 1)))
                    alpha = alpha.to(positive_flat.dtype)

                    scaled_uncond = noise_pred_uncond * alpha
                    noise_pred = scaled_uncond + self.guidance_scale * (noise_pred_cond - scaled_uncond)
            else:
                noise_pred = noise_pred_cond

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # call the callback, if provided
            if callback_on_step_end is not None:
                # The requested tensors are looked up by name among this function's locals, so
                # `latents`, `prompt_embeds` and `negative_prompt_embeds` must keep these names
                # and the lookup must stay in a plain loop (not a comprehension scope).
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if not output_type == "latent":
        latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return CogView4PipelineOutput(images=image)
Optional, Union import numpy as np import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast, ) from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, FluxTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import FluxPipeline >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly vary. >>> # Refer to the pipeline documentation for more details. 
def optimized_scale(positive_flat, negative_flat):
    """Return the per-sample projection coefficient of the conditional onto the unconditional prediction.

    Computes ``st_star = <v_cond, v_uncond> / (||v_uncond||^2 + 1e-8)`` row by row.

    The inputs are upcast to float32 explicitly. This replaces the former
    ``torch.cuda.amp.autocast`` decorator, which is deprecated and only takes
    effect on CUDA: elsewhere the reductions ran in the input dtype, where a
    float16 sum of squares can overflow.

    Args:
        positive_flat: Conditional prediction flattened to ``(rows, features)``.
        negative_flat: Unconditional prediction with the same shape.

    Returns:
        A float32 tensor of shape ``(rows, 1)``.
    """
    positive = positive_flat.float()
    negative = negative_flat.float()

    # Row-wise dot product between the two predictions.
    dot_product = torch.sum(positive * negative, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction; the epsilon avoids division by zero.
    squared_norm = torch.sum(negative**2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    return dot_product / squared_norm


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the shift ``mu`` for a given image sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; values outside that interval are extrapolated.
    """
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom schedule that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, image_encoder=image_encoder, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode ``prompt`` with the second tokenizer / text encoder pair.

    Args:
        prompt: A single prompt or a list of prompts.
        num_images_per_prompt: Number of times each prompt's embedding is repeated.
        max_sequence_length: Length every prompt is padded or truncated to.
        device: Device for the returned embeddings; falls back to
            ``self._execution_device``.
        dtype: Kept for signature compatibility. The result is always cast to
            ``self.text_encoder_2.dtype``, so this value is not used.

    Returns:
        Tensor of shape ``(batch_size * num_images_per_prompt, seq_len, dim)``.
    """
    device = device or self._execution_device
    # The first text encoder is deliberately not consulted for a dtype here: the value
    # was never used, and reading it failed when that encoder was not loaded.

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

    text_inputs = self.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation so that dropped text can be reported.
    untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Truncation happens at `max_sequence_length`, so the removed text starts there,
        # not at the first tokenizer's (shorter) limit.
        removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

    dtype = self.text_encoder_2.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # We only use the pooled prompt output from the CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) if self.text_encoder is not None: if isinstance(self, 
FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) return image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." 
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate the argument combination passed to the pipeline call.

    Logs a warning (does not raise) when ``height``/``width`` are not
    multiples of ``self.vae_scale_factor * 2``. Every other violation raises.

    Raises:
        ValueError: if an unknown callback tensor name is requested; if text
            and pre-computed embeddings are both given (or neither is, for the
            positive prompt); if a prompt is neither ``str`` nor ``list``; if
            positive/negative embeddings differ in shape; if embeddings are
            given without their pooled counterpart; or if
            ``max_sequence_length`` exceeds 512.
    """
    multiple = self.vae_scale_factor * 2
    if height % multiple != 0 or width % multiple != 0:
        logger.warning(
            f"`height` and `width` have to be divisible by {multiple} but are {height} and {width}. Dimensions will be resized accordingly"
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    # Positive prompt: every branch raises, so guard clauses keep the original precedence.
    has_embeds = prompt_embeds is not None
    if prompt is not None and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt_2 is not None and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # Negative prompt: text and embeddings are mutually exclusive.
    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )
    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")


@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build the ``(height * width, 3)`` positional-id table for packed latents.

    Column 0 stays zero, column 1 holds the row index and column 2 the column
    index of each position. ``batch_size`` is accepted but not used.
    """
    ids = torch.zeros(height, width, 3)
    ids[..., 1] += torch.arange(height)[:, None]
    ids[..., 2] += torch.arange(width)[None, :]
    return ids.reshape(height * width, 3).to(device=device, dtype=dtype)
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
    """Turn packed ``(batch, patches, channels)`` latents back into a 4-D map.

    ``height`` and ``width`` are pixel sizes; they are first converted to the
    latent grid size ``2 * (size // (vae_scale_factor * 2))`` so both are even,
    as the 2x2 packing requires. The result has shape
    ``(batch, channels // 4, latent_height, latent_width)``.
    """
    batch, _, packed_channels = latents.shape

    # VAE downscaling plus 2x2 packing: the latent grid must be even in both dimensions.
    latent_h = 2 * (int(height) // (vae_scale_factor * 2))
    latent_w = 2 * (int(width) // (vae_scale_factor * 2))

    grid = latents.view(batch, latent_h // 2, latent_w // 2, packed_channels // 4, 2, 2)
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch, packed_channels // (2 * 2), latent_h, latent_w)


def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE processing by delegating to `self.vae.enable_slicing()`.
    """
    self.vae.enable_slicing()


def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE processing by delegating to `self.vae.disable_slicing()`.
    """
    self.vae.disable_slicing()


def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE processing by delegating to `self.vae.enable_tiling()`.
    """
    self.vae.enable_tiling()


def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE processing by delegating to `self.vae.disable_tiling()`.
    """
    self.vae.disable_tiling()
@property
def guidance_scale(self):
    """Guidance scale stored by the most recent `__call__`."""
    return self._guidance_scale

@property
def joint_attention_kwargs(self):
    """Attention kwargs dict stored by `__call__` and forwarded to the transformer."""
    return self._joint_attention_kwargs

@property
def num_timesteps(self):
    """Number of timesteps of the most recent `__call__`."""
    return self._num_timesteps

@property
def current_timestep(self):
    """Timestep currently being denoised; `None` outside the denoising loop."""
    return self._current_timestep

@property
def interrupt(self):
    """When true, the remaining denoising iterations are skipped."""
    return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    true_cfg_scale: float = 1.0,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_ip_adapter_image: Optional[PipelineImageInput] = None,
    negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    use_cfg_zero_star: Optional[bool] = False,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
            be used instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
            not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
        true_cfg_scale (`float`, *optional*, defaults to 1.0):
            When > 1.0 and a negative prompt is available (either `negative_prompt`, or both
            `negative_prompt_embeds` and `negative_pooled_prompt_embeds`), a second transformer pass is run on the
            negative conditioning and the two predictions are combined as
            `neg + true_cfg_scale * (pos - neg)`.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined,
            `np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)` is used.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Value broadcast into the `guidance` tensor that is passed to the transformer when
            `self.transformer.config.guidance_embeds` is set; otherwise no guidance tensor is passed.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        negative_ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional negative image input to work with IP Adapters.
        negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated negative image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
            provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the packed latents without decoding;
            any other value is forwarded to `self.image_processor.postprocess`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the transformer. Its `"scale"` entry, if
            present, is used as the LoRA scale during prompt encoding.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
        use_cfg_zero_star (`bool`, *optional*, defaults to `False`):
            Accepted by the signature; it is not referenced anywhere in the body of this method.
        use_zero_init (`bool`, *optional*, defaults to `True`):
            When true and true CFG is NOT active, the noise prediction is multiplied by zero for every step whose
            index `i` satisfies `i <= zero_steps`.
        zero_steps (`int`, *optional*, defaults to 0):
            Last step index (inclusive) to which `use_zero_init` applies.

    Examples:

    Returns:
        [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
        is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
        images.
    """

    # Fall back to the pipeline default size when height/width are not given (or are 0).
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # Per-call state exposed through the read-only properties.
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # The LoRA scale, if any, travels inside joint_attention_kwargs under "scale".
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    # True CFG needs a scale above 1 AND some form of negative conditioning.
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
    (
        prompt_embeds,
        pooled_prompt_embeds,
        text_ids,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    if do_true_cfg:
        (
            negative_prompt_embeds,
            negative_pooled_prompt_embeds,
            _,
        ) = self.encode_prompt(
            prompt=negative_prompt,
            prompt_2=negative_prompt_2,
            prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    image_seq_len = latents.shape[1]
    # Shift parameter derived from the packed sequence length; scheduler config values override the defaults.
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.get("base_image_seq_len", 256),
        self.scheduler.config.get("max_image_seq_len", 4096),
        self.scheduler.config.get("base_shift", 0.5),
        self.scheduler.config.get("max_shift", 1.15),
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # handle guidance
    if self.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None

    # If only one side (positive / negative) of the IP-Adapter inputs was supplied, fill the other side
    # with all-zero uint8 images, one per loaded IP-Adapter.
    # NOTE(review): the zero image is allocated as (width, height, 3) — confirm the intended axis order.
    if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
        negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
    ):
        negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
        negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
    ):
        ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
        ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

    # The loop below stores IP-Adapter embeddings in this dict, so it must exist.
    if self.joint_attention_kwargs is None:
        self._joint_attention_kwargs = {}

    image_embeds = None
    negative_image_embeds = None
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )
    if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
        negative_image_embeds = self.prepare_ip_adapter_image_embeds(
            negative_ip_adapter_image,
            negative_ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
        )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once interrupted, every remaining iteration is skipped without touching the latents.
            if self.interrupt:
                continue

            self._current_timestep = t
            if image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # Positive (conditional) prediction; the transformer receives the timestep divided by 1000.
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            if do_true_cfg:
                # Swap in the negative IP-Adapter embeddings for the negative pass.
                if negative_image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
                neg_noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=negative_pooled_prompt_embeds,
                    encoder_hidden_states=negative_prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
            else:
                # Zero-init: without true CFG, the prediction is zeroed for steps 0..zero_steps.
                if (i <= zero_steps) and use_zero_init:
                    noise_pred = noise_pred*0.

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Restore the original dtype only when MPS is available (see linked PyTorch issue).
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                # Requested tensors are looked up by name among this function's local variables.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace latents and/or prompt_embeds for the following steps.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # advance the progress bar on the last step, or after warmup on every `scheduler.order`-th step
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if output_type == "latent":
        # Packed latents are returned as-is, without unpacking or VAE decoding.
        image = latents
    else:
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)
diffusers.pipelines.hidream_image.pipeline_output import HiDreamImagePipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel >>> scheduler = UniPCMultistepScheduler( ... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True ... ) >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", ... output_hidden_states=True, ... output_attentions=True, ... torch_dtype=torch.bfloat16, ... ) >>> transformer = HiDreamImageTransformer2DModel.from_pretrained( ... "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16 ... ) >>> pipe = HiDreamImagePipeline.from_pretrained( ... "HiDream-ai/HiDream-I1-Full", ... scheduler=scheduler, ... tokenizer_4=tokenizer_4, ... text_encoder_4=text_encoder_4, ... transformer=transformer, ... torch_dtype=torch.bfloat16, ... ) >>> pipe.enable_model_cpu_offload() >>> image = pipe( ... 'A cat holding a sign that says "Hi-Dreams.ai".', ... height=1024, ... width=1024, ... guidance_scale=5.0, ... num_inference_steps=50, ... generator=torch.Generator("cuda").manual_seed(0), ... 
@torch.cuda.amp.autocast(dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    """Per-row projection coefficient of the positive onto the negative vectors.

    For each row computes ``<positive, negative> / (||negative||^2 + 1e-8)``.
    Both inputs are indexed as ``(batch, features)``; the result keeps the
    reduced dimension, i.e. has shape ``(batch, 1)``.
    """
    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    numerator = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    # The epsilon keeps the division finite for an all-zero negative row.
    denominator = negative_flat.pow(2).sum(dim=1, keepdim=True) + 1e-8
    return numerator / denominator


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map a sequence length to a shift value.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; values outside that range are extrapolated.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: custom `timesteps`, custom `sigmas`, or a plain `num_inference_steps`.
    Any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the schedule length for custom schedules, otherwise `num_inference_steps` unchanged).

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own schedule.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode `prompt` with `tokenizer_3` / `text_encoder_3`.

    Args:
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Requested token limit; the effective limit is the
            smaller of this and `self.tokenizer_3.model_max_length`.
        device: Target device; defaults to `self._execution_device`.
        dtype: Target dtype; defaults to `self.text_encoder_3.dtype`.

    Returns:
        The first output of `text_encoder_3`, cast to `dtype` on `device`.
        A warning is logged when the prompt had to be truncated.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder_3.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    max_length = min(max_sequence_length, self.tokenizer_3.model_max_length)

    text_inputs = self.tokenizer_3(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    attention_mask = text_inputs.attention_mask

    # Tokenize again without truncation purely to detect (and report) what was cut off.
    untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_length} tokens: {removed_text}"
        )

    prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    return prompt_embeds


def _get_clip_prompt_embeds(
    self,
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode `prompt` with the given CLIP `tokenizer` / `text_encoder` pair.

    Args:
        tokenizer: Tokenizer to use (callable, with `batch_decode`).
        text_encoder: Encoder to use; its first output is returned.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: Requested token limit; the effective limit is
            `min(max_sequence_length, 218)`.
        device: Target device; defaults to `self._execution_device`.
        dtype: Target dtype; defaults to `text_encoder.dtype`.

    Returns:
        The first output of `text_encoder`, cast to `dtype` on `device`.
        A warning is logged when the prompt had to be truncated.
    """
    device = device or self._execution_device
    dtype = dtype or text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    # Effective limit actually applied by the tokenizer below. The truncation warning must use the
    # same value: the previous hard-coded 218 reported the wrong removed text and the wrong limit
    # whenever max_sequence_length < 218 (which includes the default of 128).
    max_length = min(max_sequence_length, 218)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation purely to detect (and report) what was cut off.
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {max_length} tokens: {removed_text}"
        )
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

    # Use the first output of the text encoder (the pooled/projected embedding for the CLIP encoders used here)
    prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    return prompt_embeds
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    prompt_3: Union[str, List[str]],
    prompt_4: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    negative_prompt_4: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 128,
    lora_scale: Optional[float] = None,
):
    """Build positive and, when requested, negative prompt embeddings.

    The positive prompts are always passed through `self._encode_prompt`. The negative
    prompts are encoded only when `do_classifier_free_guidance` is truthy and no
    `negative_prompt_embeds` were supplied. A missing `negative_prompt` defaults to
    the empty string, and missing `negative_prompt_2/3/4` fall back to `negative_prompt`.

    Note: `lora_scale` is accepted for interface compatibility but is not used here.

    Returns:
        `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)`.

    Raises:
        TypeError: `prompt` and `negative_prompt` are of different types.
        ValueError: `negative_prompt` does not match the prompt batch size.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    # Batch size comes from the prompt list, or from the supplied embeddings otherwise.
    if prompt is None:
        reference = prompt_embeds[0] if isinstance(prompt_embeds, list) else prompt_embeds
        batch_size = reference.shape[0]
    else:
        batch_size = len(prompt)

    # Arguments shared by the positive and the negative encoding pass.
    shared = {
        "device": device,
        "dtype": dtype,
        "num_images_per_prompt": num_images_per_prompt,
        "max_sequence_length": max_sequence_length,
    }

    prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        prompt_4=prompt_4,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        **shared,
    )

    # Nothing more to do without guidance, or when negatives were precomputed.
    if not do_classifier_free_guidance or negative_prompt_embeds is not None:
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    base_negative = negative_prompt or ""
    negatives = [base_negative] + [
        alternative or base_negative
        for alternative in (negative_prompt_2, negative_prompt_3, negative_prompt_4)
    ]
    # normalize str to list
    negatives = [batch_size * [entry] if isinstance(entry, str) else entry for entry in negatives]
    negative_prompt, negative_prompt_2, negative_prompt_3, negative_prompt_4 = negatives

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
        prompt=negative_prompt,
        prompt_2=negative_prompt_2,
        prompt_3=negative_prompt_3,
        prompt_4=negative_prompt_4,
        prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=negative_pooled_prompt_embeds,
        **shared,
    )
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
pooled_prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype ) pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype ) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) if prompt_embeds is None: prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_4 = prompt_4 or prompt prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) _, seq_len, _ = t5_prompt_embeds.shape t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1) t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) _, _, seq_len, dim = llama3_prompt_embeds.shape llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] return prompt_embeds, pooled_prompt_embeds def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    prompt_4: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    negative_prompt_4: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 128,
    use_cfg_zero_star: Optional[bool] = True,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
):
    """Generate images, optionally applying CFG-Zero* guidance.

    Args:
        prompt, prompt_2, prompt_3, prompt_4: Prompts forwarded to `encode_prompt`.
        height, width: Requested size; each defaults to
            `default_sample_size * vae_scale_factor`. The size is then rescaled so its
            area equals the default area and rounded down to a multiple of
            `vae_scale_factor * 2`.
        num_inference_steps: Number of denoising steps requested from the scheduler.
        sigmas: Custom sigmas forwarded to `retrieve_timesteps` (not used for
            `UniPCMultistepScheduler`).
        guidance_scale: Guidance weight; classifier-free guidance is active when > 1.
        negative_prompt, negative_prompt_2/3/4: Negative prompts for guidance.
        num_images_per_prompt: Images generated per prompt.
        generator, latents: Forwarded to `prepare_latents`.
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds: Optional precomputed embeddings.
        output_type: `"latent"` returns the raw latents; anything else is decoded by
            the VAE and post-processed.
        return_dict: Return `HiDreamImagePipelineOutput` when true, else a 1-tuple.
        attention_kwargs: Stored on the pipeline; its `"scale"` entry is read as the
            LoRA scale passed to `encode_prompt`.
        callback_on_step_end: Called as `callback(self, i, t, callback_kwargs)` after
            every scheduler step.
        callback_on_step_end_tensor_inputs: Names of local tensors handed to the callback.
        max_sequence_length: Token budget forwarded to `encode_prompt`.
        use_cfg_zero_star: With guidance active, rescale the unconditional prediction
            per sample by `optimized_scale(...)` before combining.
        use_zero_init: Replace the prediction with zeros for steps `i <= zero_steps`.
            This applies in the CFG-Zero* branch and when guidance is off; the plain
            CFG branch does not zero-initialise.
        zero_steps: Last step index (inclusive) that is zero-initialised.

    Returns:
        `HiDreamImagePipelineOutput(images=...)`, or `(images,)` if `return_dict` is false.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # Rescale to the default pixel area, then round each side down to a multiple of
    # `vae_scale_factor * 2` (VAE compression times the 2x2 latent packing).
    division = self.vae_scale_factor * 2
    S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
    scale = math.sqrt(S_max / (width * height))
    width, height = int(width * scale // division * division), int(height * scale // division * division)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
    else:
        batch_size = 1

    device = self._execution_device

    lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        prompt_4=prompt_4,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        negative_prompt_4=negative_prompt_4,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    if self.do_classifier_free_guidance:
        # Stack [negative, positive] along the batch axis: dim 0 for 3-D embeddings,
        # dim 1 for the other (stacked per-layer) embeddings.
        prompt_embeds_arr = []
        for n, p in zip(negative_prompt_embeds, prompt_embeds):
            if len(n.shape) == 3:
                prompt_embeds_arr.append(torch.cat([n, p], dim=0))
            else:
                prompt_embeds_arr.append(torch.cat([n, p], dim=1))
        prompt_embeds = prompt_embeds_arr
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pooled_prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # NOTE: `img_sizes` / `img_ids` are built for non-square latents but are currently
    # not passed to the transformer (the arguments are commented out below).
    if latents.shape[-2] != latents.shape[-1]:
        B, C, H, W = latents.shape
        pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size

        img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
        img_ids = torch.zeros(pH, pW, 3)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
        img_ids = img_ids.reshape(pH * pW, -1)
        img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
        img_ids_pad[: pH * pW, :] = img_ids

        img_sizes = img_sizes.unsqueeze(0).to(latents.device)
        img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
        if self.do_classifier_free_guidance:
            img_sizes = img_sizes.repeat(2 * B, 1)
            img_ids = img_ids.repeat(2 * B, 1, 1)
    else:
        img_sizes = img_ids = None

    # 5. Prepare timesteps
    mu = calculate_shift(self.transformer.max_seq)
    scheduler_kwargs = {"mu": mu}
    if isinstance(self.scheduler, UniPCMultistepScheduler):
        self.scheduler.set_timesteps(num_inference_steps, device=device)  # , shift=math.exp(mu))
        timesteps = self.scheduler.timesteps
    else:
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timesteps=timestep,
                encoder_hidden_states_t5=prompt_embeds[0],
                encoder_hidden_states_llama3=prompt_embeds[1],
                pooled_embeds=pooled_prompt_embeds,
                # img_sizes=img_sizes,
                # img_ids=img_ids,
                return_dict=False,
            )[0]
            noise_pred = -noise_pred

            zero_init_step = (i <= zero_steps) and use_zero_init

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                if use_cfg_zero_star:
                    if zero_init_step:
                        noise_pred = noise_pred_text * 0.0
                    else:
                        # Use the real per-sample batch of the prediction. The prompt
                        # count alone is wrong when `num_images_per_prompt > 1`: the
                        # scale would be shared across images or fail to broadcast.
                        sample_count = noise_pred_text.shape[0]
                        positive_flat = noise_pred_text.reshape(sample_count, -1)
                        negative_flat = noise_pred_uncond.reshape(sample_count, -1)

                        alpha = optimized_scale(positive_flat, negative_flat)
                        alpha = alpha.view(sample_count, *([1] * (noise_pred_text.dim() - 1)))
                        alpha = alpha.to(positive_flat.dtype)

                        noise_pred = noise_pred_uncond * alpha + self.guidance_scale * (
                            noise_pred_text - noise_pred_uncond * alpha
                        )
                else:
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            elif zero_init_step:
                noise_pred = noise_pred * 0.0

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # advance the progress bar on the last step and on every completed scheduler-order group after warmup
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents
    else:
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return HiDreamImagePipelineOutput(images=image)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.loaders import HunyuanVideoLoraLoaderMixin from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel >>> from diffusers.utils import export_to_video >>> model_id = "hunyuanvideo-community/HunyuanVideo" >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ... ) >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") >>> output = pipe( ... prompt="A cat walks on the grass, realistic", ... height=320, ... width=512, ... num_frames=61, ... num_inference_steps=30, ... 
def optimized_scale(positive_flat, negative_flat):
    """Return the per-sample CFG-Zero* scale between two flattened predictions.

    Computes `st_star = v_cond^T * v_uncond / ||v_uncond||^2` row by row.

    Half-precision inputs (`float16` / `bfloat16`) are upcast to `float32` before the
    reductions, so the result for them is `float32`. This is done explicitly rather
    than through a CUDA autocast decorator, which is deprecated and is disabled when
    CUDA is unavailable; without the upcast the sums can overflow to `inf` and yield
    a NaN scale.

    Args:
        positive_flat: Conditional prediction, shape `(batch, features)`.
        negative_flat: Unconditional prediction, shape `(batch, features)`.

    Returns:
        Tensor of shape `(batch, 1)` with the optimal scale for each sample.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    if positive_flat.dtype in half_dtypes:
        positive_flat = positive_flat.float()
    if negative_flat.dtype in half_dtypes:
        negative_flat = negative_flat.float()

    # Dot product between conditional and unconditional predictions
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction; epsilon guards against division by zero
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm
    return st_star
timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class HunyuanVideoCFGZeroPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): r""" Pipeline for text-to-video generation using HunyuanVideo. 
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: text_encoder ([`LlamaModel`]): [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). tokenizer (`LlamaTokenizer`): Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). transformer ([`HunyuanVideoTransformer3DModel`]): Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLHunyuanVideo`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. text_encoder_2 ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer_2 (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
def _get_llama_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    prompt_template: Dict[str, Any],
    num_videos_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 256,
    num_hidden_layers_to_skip: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode templated prompts with `self.text_encoder` and crop the template prefix.

    Each prompt is inserted into `prompt_template["template"]`, tokenized to
    `max_sequence_length + crop_start` tokens and encoded. The hidden state at index
    `-(num_hidden_layers_to_skip + 1)` is used, and the first `crop_start` positions
    are dropped from both the embeddings and the attention mask. When the template
    has no `"crop_start"` entry, it is derived by tokenizing the bare template.

    Returns:
        `(prompt_embeds, prompt_attention_mask)`, each repeated
        `num_videos_per_prompt` times per prompt.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompts = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompts)
    template = prompt_template["template"]
    templated = [template.format(text) for text in prompts]

    crop_start = prompt_template.get("crop_start", None)
    if crop_start is None:
        template_tokens = self.tokenizer(
            prompt_template["template"],
            padding="max_length",
            return_tensors="pt",
            return_length=False,
            return_overflowing_tokens=False,
            return_attention_mask=False,
        )
        # Remove <|eot_id|> token and placeholder {}
        crop_start = template_tokens["input_ids"].shape[-1] - 2

    # The template prefix is cropped afterwards, so budget extra tokens for it.
    total_length = max_sequence_length + crop_start
    tokens = self.tokenizer(
        templated,
        max_length=total_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )
    input_ids = tokens.input_ids.to(device=device)
    attention = tokens.attention_mask.to(device=device)

    encoder_output = self.text_encoder(
        input_ids=input_ids,
        attention_mask=attention,
        output_hidden_states=True,
    )
    embeds = encoder_output.hidden_states[-(num_hidden_layers_to_skip + 1)].to(dtype=dtype)

    if crop_start is not None and crop_start > 0:
        embeds = embeds[:, crop_start:]
        attention = attention[:, crop_start:]

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    seq_len = embeds.shape[1]
    total = batch_size * num_videos_per_prompt
    embeds = embeds.repeat(1, num_videos_per_prompt, 1).view(total, seq_len, -1)
    attention = attention.repeat(1, num_videos_per_prompt).view(total, seq_len)

    return embeds, attention
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]] = None,
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 256,
):
    """Compute any prompt embeddings the caller did not supply.

    Args:
        prompt: Prompt(s) for `_get_llama_prompt_embeds`; also the fallback for `prompt_2`.
        prompt_2: Prompt(s) for `_get_clip_prompt_embeds`. Defaults to `prompt`.
        prompt_template: Template dict forwarded to `_get_llama_prompt_embeds`.
        num_videos_per_prompt: Forwarded to both encoders.
        prompt_embeds: Precomputed embeddings; when given, the first encoder is skipped
            and `prompt_attention_mask` is returned as passed in.
        pooled_prompt_embeds: Precomputed pooled embeddings; when given, the CLIP
            encoder is skipped.
        prompt_attention_mask: Attention mask that accompanies `prompt_embeds`.
        device, dtype: Forwarded to both encoders.
        max_sequence_length: Token budget for `_get_llama_prompt_embeds`; the CLIP
            encoder is always called with 77.

    Returns:
        `(prompt_embeds, pooled_prompt_embeds, prompt_attention_mask)`.
    """
    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
            prompt,
            prompt_template,
            num_videos_per_prompt,
            device=device,
            dtype=dtype,
            max_sequence_length=max_sequence_length,
        )

    if pooled_prompt_embeds is None:
        if prompt_2 is None:
            prompt_2 = prompt
        # Encode `prompt_2` here. It was previously defaulted and then ignored, so a
        # caller-supplied second prompt had no effect on the pooled embedding.
        pooled_prompt_embeds = self._get_clip_prompt_embeds(
            prompt_2,
            num_videos_per_prompt,
            device=device,
            dtype=dtype,
            max_sequence_length=77,
        )

    return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int = 32,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 129,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the starting latents for the denoising loop.

    Caller-supplied `latents` are only moved to `device` / `dtype`; their shape is not
    checked. Otherwise random latents are drawn with shape
    `(batch_size, num_channels_latents, (num_frames - 1) // temporal + 1,
    height // spatial, width // spatial)`, where `temporal` and `spatial` are the
    pipeline's VAE scale factors.

    Raises:
        ValueError: `generator` is a list whose length differs from `batch_size`.
    """
    if latents is None:
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latent_height = int(height) // self.vae_scale_factor_spatial
        latent_width = int(width) // self.vae_scale_factor_spatial
        shape = (batch_size, num_channels_latents, latent_frames, latent_height, latent_width)

        generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
        if generator_count_mismatch:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return latents.to(device=device, dtype=dtype)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    negative_prompt_2: Union[str, List[str]] = None,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 129,
    num_inference_steps: int = 50,
    sigmas: List[float] = None,
    true_cfg_scale: float = 1.0,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    max_sequence_length: int = 256,
    use_cfg_zero_star: Optional[bool] = False,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            Secondary prompt, forwarded to `encode_prompt` as `prompt_2`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the generation. Only encoded when true classifier-free
            guidance is active, i.e. when `true_cfg_scale > 1` and either `negative_prompt` or both
            `negative_prompt_embeds` and `negative_pooled_prompt_embeds` are given.
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            Secondary negative prompt, forwarded to `encode_prompt` as `prompt_2` for the negative branch.
        height (`int`, defaults to `720`):
            The height in pixels of the generated video. Must be divisible by 16.
        width (`int`, defaults to `1280`):
            The width in pixels of the generated video. Must be divisible by 16.
        num_frames (`int`, defaults to `129`):
            The number of frames in the generated video.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps.
        sigmas (`List[float]`, *optional*):
            Custom sigmas passed to the scheduler through `retrieve_timesteps`. If not defined,
            `np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1]` is used.
        true_cfg_scale (`float`, *optional*, defaults to 1.0):
            When > 1.0 and a negative prompt (or negative embeddings) is provided, a second transformer pass
            is run on the negative conditioning and the two predictions are combined as
            `neg + true_cfg_scale * (pos - neg)`.
        guidance_scale (`float`, defaults to `6.0`):
            Value placed, multiplied by 1000, into the `guidance` tensor that is passed to the transformer on
            every step. It is not used to combine conditional and unconditional predictions; that is governed
            by `true_cfg_scale`.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic. Forwarded to `prepare_latents`.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents. If not provided, a latents tensor is generated by `prepare_latents`
            using the supplied `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Mutually exclusive with `prompt` / `prompt_2`.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, they are generated by `encode_prompt`.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `prompt_embeds`. The value returned by `encode_prompt` is cast with
            `.to(...)`, so it must not be `None` at that point.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings, used only when true classifier-free guidance is active.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings, used only when true classifier-free guidance is
            active.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `negative_prompt_embeds`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            `"latent"` returns the final latents without decoding. Any other value is forwarded to
            `video_processor.postprocess_video` after VAE decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is passed to the transformer as `attention_kwargs` on every forward pass.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            Called at the end of each denoising step as
            `callback_on_step_end(self, step, timestep, callback_kwargs)`. The returned mapping may contain
            `latents` and/or `prompt_embeds`, which then replace the current values.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Names of the tensors placed into `callback_kwargs`. Every name must be listed in the
            `._callback_tensor_inputs` attribute of the pipeline class. Replaced by
            `callback_on_step_end.tensor_inputs` when the callback is a `PipelineCallback` or
            `MultiPipelineCallbacks`.
        prompt_template (`Dict[str, Any]`):
            Prompt template forwarded to `check_inputs` and `encode_prompt`.
        max_sequence_length (`int`, defaults to `256`):
            Forwarded to `encode_prompt`.
        use_cfg_zero_star (`bool`, *optional*, defaults to `False`):
            Accepted for interface compatibility; NOTE(review): this value is not read anywhere in this
            method.
        use_zero_init (`bool`, *optional*, defaults to `True`):
            When true classifier-free guidance is NOT active, the model prediction is multiplied by zero for
            every step index `i <= zero_steps`.
        zero_steps (`int`, *optional*, defaults to `0`):
            Last step index (inclusive) affected by `use_zero_init`.

    Examples:

    Returns:
        [`~HunyuanVideoPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, `HunyuanVideoPipelineOutput(frames=video)` is returned, otherwise the
            one-element tuple `(video,)`.
    """

    # Callback objects carry their own list of tensor names; it overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds,
        callback_on_step_end_tensor_inputs,
        prompt_template,
    )

    # True CFG needs a scale above 1 AND some form of negative conditioning.
    has_neg_prompt = negative_prompt is not None or (
        negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
    )
    do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

    # State exposed through the read-only properties of the pipeline.
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    transformer_dtype = self.transformer.dtype
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_template=prompt_template,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        device=device,
        max_sequence_length=max_sequence_length,
    )
    # All conditioning tensors, including the mask, are cast to the transformer dtype.
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
    pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

    if do_true_cfg:
        # Same encoding path for the negative conditioning, using the same prompt template.
        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
            prompt=negative_prompt,
            prompt_2=negative_prompt_2,
            prompt_template=prompt_template,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=negative_pooled_prompt_embeds,
            prompt_attention_mask=negative_prompt_attention_mask,
            device=device,
            max_sequence_length=max_sequence_length,
        )
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    # Default schedule: evenly spaced sigmas from 1.0 down towards 0.0, with the final 0.0 dropped.
    sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

    # 5. Prepare latent variables
    # Latents are kept in float32 and only cast to the transformer dtype for each forward pass.
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Prepare guidance condition
    # One entry per latent batch element, scaled by 1000 before being handed to the transformer.
    guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once interrupted, every remaining iteration is skipped (including the progress update).
            if self.interrupt:
                continue

            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # Conditional (positive prompt) prediction.
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

            if do_true_cfg:
                # Second pass on the negative conditioning, then the standard CFG combination.
                neg_noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_attention_mask=negative_prompt_attention_mask,
                    pooled_projections=negative_pooled_prompt_embeds,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
            else:
                # Zero-init: without true CFG, the prediction is zeroed for step indices up to `zero_steps`.
                if (i <= zero_steps) and use_zero_init:
                    noise_pred = noise_pred*0.

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Requested tensors are looked up by name among this function's local variables.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may hand back replacements for these two tensors.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            # advance the progress bar on the last step and on every completed scheduler-order step after warmup
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if not output_type == "latent":
        # Undo the VAE scaling factor, decode, and convert to the requested output type.
        latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return HunyuanVideoPipelineOutput(frames=video)
import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import ( BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, PreTrainedModel, T5EncoderModel, T5TokenizerFast, ) from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import SD3Transformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusion3Pipeline >>> pipe = StableDiffusion3Pipeline.from_pretrained( ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 ... 
@torch.cuda.amp.autocast(dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    """
    Return the per-row projection coefficient of `positive_flat` onto `negative_flat`.

    Both inputs are reduced over dim 1 (kept as a size-1 dim), giving
    `sum(pos * neg) / (sum(neg ** 2) + 1e-8)` for every row.
    """
    eps = 1e-8
    # <v_cond, v_uncond> for every row
    numerator = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    # ||v_uncond||^2 for every row; eps keeps the division finite for an all-zero row
    denominator = negative_flat.pow(2).sum(dim=1, keepdim=True) + eps
    return numerator / denominator


# Adapted from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly map `image_seq_len` onto a shift value.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`; values outside
    that interval are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or plain
    `num_inference_steps`. Extra keyword arguments are passed through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of steps, used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` and the number of inference steps. For custom
        `timesteps`/`sigmas` the count is the length of the resulting schedule; otherwise it is
        `num_inference_steps` as passed in.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps`
            does not accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain mode: let the scheduler build its own spacing.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must expose the matching keyword on `set_timesteps`.
    accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if custom_timesteps:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` as its dimension. text_encoder_2 ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. text_encoder_3 ([`T5EncoderModel`]): Frozen text-encoder. Stable Diffusion 3 uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). image_encoder (`PreTrainedModel`, *optional*): Pre-trained Vision Model for IP Adapter. feature_extractor (`BaseImageProcessor`, *optional*): Image processor for IP Adapter. 
def __init__(
    self,
    transformer: SD3Transformer2DModel,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer_2: CLIPTokenizer,
    text_encoder_3: T5EncoderModel,
    tokenizer_3: T5TokenizerFast,
    image_encoder: PreTrainedModel = None,
    feature_extractor: BaseImageProcessor = None,
):
    """
    Register all sub-models and derive the size constants used by the rest of the pipeline.

    Each derived constant is read from the corresponding registered component when that component is
    present and not `None`; otherwise a fixed fallback is used (VAE scale factor 8, tokenizer length 77,
    sample size 128, patch size 2).
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        text_encoder_3=text_encoder_3,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        tokenizer_3=tokenizer_3,
        transformer=transformer,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )

    # Spatial down-scaling of the VAE: one factor of two per block transition.
    vae_available = hasattr(self, "vae") and self.vae is not None
    if vae_available:
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    else:
        self.vae_scale_factor = 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Maximum token length taken from the first tokenizer.
    tokenizer_available = hasattr(self, "tokenizer") and self.tokenizer is not None
    if tokenizer_available:
        self.tokenizer_max_length = self.tokenizer.model_max_length
    else:
        self.tokenizer_max_length = 77

    # Default sample size and patch size come from the transformer config.
    transformer_available = hasattr(self, "transformer") and self.transformer is not None
    if transformer_available:
        self.default_sample_size = self.transformer.config.sample_size
    else:
        self.default_sample_size = 128

    transformer_available = hasattr(self, "transformer") and self.transformer is not None
    if transformer_available:
        self.patch_size = self.transformer.config.patch_size
    else:
        self.patch_size = 2
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
):
    """
    Encode `prompt` with one of the two CLIP text encoders.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or batch of prompts; a single string is treated as a batch of one.
        num_images_per_prompt (`int`):
            Number of consecutive copies of each prompt's embeddings in the output.
        device (`torch.device`, *optional*):
            Target device; defaults to `self._execution_device`.
        clip_skip (`int`, *optional*):
            When `None`, hidden state `-2` is used; otherwise hidden state `-(clip_skip + 2)`.
        clip_model_index (`int`):
            `0` selects `tokenizer`/`text_encoder`, `1` selects `tokenizer_2`/`text_encoder_2`.

    Returns:
        `tuple`: `(prompt_embeds, pooled_prompt_embeds)` with leading dimension
        `batch_size * num_images_per_prompt`. Both tensors use the same row order: all copies of prompt 0,
        then all copies of prompt 1, and so on.
    """
    device = device or self._execution_device

    clip_tokenizers = [self.tokenizer, self.tokenizer_2]
    clip_text_encoders = [self.text_encoder, self.text_encoder_2]

    tokenizer = clip_tokenizers[clip_model_index]
    text_encoder = clip_text_encoders[clip_model_index]

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation purely to report what was cut off.
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer_max_length} tokens: {removed_text}"
        )
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    # Element 0 of the encoder output is the pooled embedding.
    pooled_prompt_embeds = prompt_embeds[0]

    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # Fix: tile the pooled embedding along its feature axis before reshaping, so each prompt's copies
    # are adjacent (p0, p0, p1, p1, ...) exactly like `prompt_embeds` above. The previous
    # `repeat(1, n, 1)` on this 2-D tensor tiled along the batch axis instead (p0, p1, p0, p1, ...),
    # pairing pooled and sequence embeddings of different prompts whenever batch_size > 1 and
    # num_images_per_prompt > 1.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds
clip_skip: Optional[int] = None, max_sequence_length: int = 256, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( 
prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt,
    prompt_2,
    prompt_3,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    negative_prompt_3=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate the arguments of a generation call; return `None` when they are consistent.

    Checks performed, in order:

    * `height` and `width` are multiples of `self.vae_scale_factor * self.patch_size`.
    * every name in `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`.
    * text prompts (`prompt`, `prompt_2`, `prompt_3`) and `prompt_embeds` are mutually exclusive, at least
      one of `prompt` / `prompt_embeds` is given, and text prompts are `str` or `list`.
    * negative text prompts and `negative_prompt_embeds` are mutually exclusive.
    * `prompt_embeds` and `negative_prompt_embeds` have the same shape when both are given.
    * pooled embeddings accompany their non-pooled counterparts.
    * `max_sequence_length` does not exceed 512.

    Raises:
        ValueError: if any of the checks above fails.
    """
    # Images must split evenly into VAE-downsampled patches.
    size_multiple = self.vae_scale_factor * self.patch_size
    if height % size_multiple != 0 or width % size_multiple != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {size_multiple} but are {height} and {width}."
            f" You can use height {height - height % size_multiple} and width {width - width % size_multiple}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_3 is not None and prompt_embeds is not None:
        # Report the offending `prompt_3` value (the message previously echoed `prompt_2`).
        raise ValueError(
            f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
    elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
        raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")


def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return the initial latents for the denoising loop.

    When `latents` is supplied it is only moved to `device` / `dtype` and returned as-is (no shape or
    generator validation happens on that path). Otherwise noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn
    with `randn_tensor`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size` (only checked when
            new latents are sampled).
    """
    if latents is not None:
        return latents.to(device=device, dtype=dtype)

    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return latents


@property
def guidance_scale(self):
    """Value stored in `self._guidance_scale`."""
    return self._guidance_scale


@property
def skip_guidance_layers(self):
    """Value stored in `self._skip_guidance_layers`."""
    return self._skip_guidance_layers


@property
def clip_skip(self):
    """Value stored in `self._clip_skip`."""
    return self._clip_skip


# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """`True` when the stored guidance scale is greater than 1."""
    return self._guidance_scale > 1


@property
def joint_attention_kwargs(self):
    """Value stored in `self._joint_attention_kwargs`."""
    return self._joint_attention_kwargs


@property
def num_timesteps(self):
    """Value stored in `self._num_timesteps`."""
    return self._num_timesteps


@property
def interrupt(self):
    """Value stored in `self._interrupt`."""
    return self._interrupt


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
    """Encodes the given image into a feature representation using a pre-trained image encoder.

    Args:
        image (`PipelineImageInput`):
            Input image to be encoded. Anything that is not already a `torch.Tensor` is first passed through
            `self.feature_extractor`.
        device: (`torch.device`):
            Torch device.

    Returns:
        `torch.Tensor`: The second-to-last hidden state returned by `self.image_encoder`.
    """
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=self.dtype)

    return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
    """Prepares image embeddings for use in the IP-Adapter.

    Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed; when both are given the
    precomputed embeddings take precedence.

    Args:
        ip_adapter_image (`PipelineImageInput`, *optional*):
            The input image to extract features from for IP-Adapter.
        ip_adapter_image_embeds (`torch.Tensor`, *optional*):
            Precomputed image embeddings. With classifier-free guidance the tensor is split in two halves
            along dim 0: the first half is used as the negative embedding, the second as the positive one.
        device: (`torch.device`, *optional*):
            Torch device. Falls back to `self._execution_device`.
        num_images_per_prompt (`int`, defaults to 1):
            Number of images that should be generated per prompt.
        do_classifier_free_guidance (`bool`, defaults to True):
            Whether to use classifier free guidance or not.

    Returns:
        `torch.Tensor`: The embeddings repeated `num_images_per_prompt` times along dim 0; with
        classifier-free guidance the repeated negative embeddings are concatenated in front.

    Raises:
        ValueError: if neither `ip_adapter_image` nor `ip_adapter_image_embeds` is provided.
    """
    device = device or self._execution_device

    if ip_adapter_image_embeds is not None:
        if do_classifier_free_guidance:
            single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
        else:
            single_image_embeds = ip_adapter_image_embeds
    elif ip_adapter_image is not None:
        single_image_embeds = self.encode_image(ip_adapter_image, device)
        if do_classifier_free_guidance:
            # an all-zero embedding serves as the unconditional image input
            single_negative_image_embeds = torch.zeros_like(single_image_embeds)
    else:
        raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")

    image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)

    if do_classifier_free_guidance:
        negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

    return image_embeds.to(device=device)


def enable_sequential_cpu_offload(self, *args, **kwargs):
    """Warn about a possibly incompatible `image_encoder`, then defer to the parent implementation.

    The warning is logged when `self.image_encoder` is set and `"image_encoder"` is not listed in
    `self._exclude_from_cpu_offload`. All arguments are forwarded unchanged.
    """
    if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
        logger.warning(
            "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
            "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
            "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
        )

    super().enable_sequential_cpu_offload(*args, **kwargs)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
    skip_guidance_layers: List[int] = None,
    skip_layer_guidance_scale: float = 2.8,
    skip_layer_guidance_stop: float = 0.2,
    skip_layer_guidance_start: float = 0.01,
    mu: Optional[float] = None,
    use_cfg_zero_star: Optional[bool] = False,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            will be used instead.
        prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
            will be used instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
            to the text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used instead.
        negative_prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
            `text_encoder_3`. If not defined, `negative_prompt` is used instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.*
            prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from
            `negative_prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`torch.Tensor`, *optional*):
            Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
            emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
            `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the final latents are returned without VAE decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
            of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is used as the LoRA scale for prompt encoding.
        clip_skip (`int`, *optional*):
            Forwarded to `encode_prompt` as `clip_skip`.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is
            called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
            timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
            specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
            the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
        skip_guidance_layers (`List[int]`, *optional*):
            A list of integers that specify layers to skip during guidance. If not provided, all layers will be
            used for guidance. If provided, the guidance will only be applied to the layers specified in the
            list. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
        skip_layer_guidance_scale (`float`, *optional*):
            The scale of the guidance for the layers specified in `skip_guidance_layers`. The guidance will be
            applied to the layers specified in `skip_guidance_layers` with a scale of
            `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers with a scale
            of `1`.
        skip_layer_guidance_stop (`float`, *optional*):
            The step at which the guidance for the layers specified in `skip_guidance_layers` will stop. The
            guidance will be applied to the layers specified in `skip_guidance_layers` until the fraction
            specified in `skip_layer_guidance_stop`. Recommended value by StabilityAI for Stable Diffusion 3.5
            Medium is 0.2.
        skip_layer_guidance_start (`float`, *optional*):
            The step at which the guidance for the layers specified in `skip_guidance_layers` will start. The
            guidance will be applied to the layers specified in `skip_guidance_layers` from the fraction
            specified in `skip_layer_guidance_start`. Recommended value by StabilityAI for Stable Diffusion 3.5
            Medium is 0.01.
        mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
        use_cfg_zero_star (`bool`, *optional*, defaults to `False`):
            When `True` (and guidance is enabled), the unconditional prediction is multiplied by a per-sample
            factor computed by `optimized_scale` before the guidance formula is applied.
        use_zero_init (`bool`, *optional*, defaults to `True`):
            Only used together with `use_cfg_zero_star`. When `True`, the noise prediction is replaced by
            zeros for every step index `i <= zero_steps`.
        zero_steps (`int`, *optional*, defaults to 0):
            Last step index (inclusive) that is zeroed when `use_zero_init` is active.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    # Backing field of the `skip_guidance_layers` property; without this assignment reading the
    # property raises AttributeError.
    self._skip_guidance_layers = skip_guidance_layers
    self._skip_layer_guidance_scale = skip_layer_guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompts
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    if self.do_classifier_free_guidance:
        if skip_guidance_layers is not None:
            # keep the conditional-only embeddings for the extra skip-layer forward pass
            original_prompt_embeds = prompt_embeds
            original_pooled_prompt_embeds = pooled_prompt_embeds
        # negative first, positive second: matches the `chunk(2)` unpacking in the loop below
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    scheduler_kwargs = {}
    if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
        # NOTE: `height` / `width` are rebound to the latent spatial size from here on
        _, _, height, width = latents.shape
        image_seq_len = (height // self.transformer.config.patch_size) * (
            width // self.transformer.config.patch_size
        )
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        scheduler_kwargs["mu"] = mu
    elif mu is not None:
        scheduler_kwargs["mu"] = mu
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        **scheduler_kwargs,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 6. Prepare image embeddings
    if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
        ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
        else:
            self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

    # NOTE(review): this value is not read again in this method
    sigmas = timesteps / self.scheduler.config.num_train_timesteps

    # 7. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                # skip the remaining steps without running the model
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                if use_cfg_zero_star:
                    # Flatten one row per sample. The leading dimension is the effective batch
                    # (prompts x images per prompt, or whatever user-supplied latents carry), which
                    # differs from the prompt `batch_size` whenever num_images_per_prompt > 1.
                    num_samples = noise_pred_text.shape[0]
                    positive_flat = noise_pred_text.reshape(num_samples, -1)
                    negative_flat = noise_pred_uncond.reshape(num_samples, -1)

                    alpha = optimized_scale(positive_flat, negative_flat)
                    # reshape to (num_samples, 1, 1, ...) so it broadcasts over each sample
                    alpha = alpha.view(num_samples, *([1] * (len(noise_pred_text.shape) - 1)))
                    alpha = alpha.to(positive_flat.dtype)

                    if (i <= zero_steps) and use_zero_init:
                        noise_pred = noise_pred_text * 0.0
                    else:
                        noise_pred = noise_pred_uncond * alpha + guidance_scale * (
                            noise_pred_text - noise_pred_uncond * alpha
                        )
                else:
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # skip-layer guidance is active strictly inside the (start, stop) fraction of the run
                should_skip_layers = (
                    num_inference_steps * skip_layer_guidance_start
                    < i
                    < num_inference_steps * skip_layer_guidance_stop
                )
                if skip_guidance_layers is not None and should_skip_layers:
                    timestep = t.expand(latents.shape[0])
                    latent_model_input = latents
                    noise_pred_skip_layers = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=original_prompt_embeds,
                        pooled_projections=original_pooled_prompt_embeds,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                        skip_layers=skip_guidance_layers,
                    )[0]
                    noise_pred = (
                        noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
                    )

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )

            # advance the progress bar on the last step and after warmup on every scheduler-order boundary
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents

    else:
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusion3PipelineOutput(images=image)
# See the License for the specific language governing permissions and # limitations under the License. import html from typing import Any, Callable, Dict, List, Optional, Union import ftfy import regex as re import torch from transformers import AutoTokenizer, UMT5EncoderModel from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.loaders import WanLoraLoaderMixin from diffusers.models import AutoencoderKLWan, WanTransformer3DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch >>> from diffusers.utils import export_to_video >>> from diffusers import AutoencoderKLWan, WanPipeline >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) >>> pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. 
# `torch.cuda.amp.autocast(...)` is deprecated; `torch.autocast("cuda", ...)` is the equivalent spelling.
@torch.autocast("cuda", dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    """Return the per-row projection coefficient of `positive_flat` onto `negative_flat`.

    For every row this computes `sum(positive * negative) / (sum(negative ** 2) + 1e-8)`, reducing over
    dim 1; the `1e-8` keeps the division finite for an all-zero row.

    Args:
        positive_flat: tensor with at least two dims; conditional predictions, one row per sample.
        negative_flat: tensor of the same shape; unconditional predictions.

    Returns:
        Tensor with dim 1 reduced to size 1 (`keepdim=True`).
    """
    # v_cond^T * v_uncond
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # ||v_uncond||^2, offset so the division below never hits zero
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    return dot_product / squared_norm


def basic_clean(text):
    """Repair text with `ftfy.fix_text`, unescape HTML entities twice, and strip outer whitespace."""
    repaired = ftfy.fix_text(text)
    # two passes so doubly-escaped entities (e.g. `&amp;amp;`) are fully decoded
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()


def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip the ends."""
    return re.sub(r"\s+", " ", text).strip()


def prompt_clean(text):
    """Apply `basic_clean` followed by `whitespace_clean`."""
    return whitespace_clean(basic_clean(text))
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    transformer: WanTransformer3DModel,
    vae: AutoencoderKLWan,
    scheduler: FlowMatchEulerDiscreteScheduler,
):
    """Register the pipeline components and derive the VAE scale factors.

    The temporal factor is `2 ** sum(vae.temperal_downsample)` and the spatial factor is
    `2 ** len(vae.temperal_downsample)`; when no VAE is registered they fall back to 4 and 8.
    A `VideoProcessor` is built with the spatial factor.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    if getattr(self, "vae", None):
        downsample_flags = self.vae.temperal_downsample
        self.vae_scale_factor_temporal = 2 ** sum(downsample_flags)
        self.vae_scale_factor_spatial = 2 ** len(downsample_flags)
    else:
        self.vae_scale_factor_temporal = 4
        self.vae_scale_factor_spatial = 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)


def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode `prompt` with the text encoder and return embeddings zeroed past each prompt's length.

    Args:
        prompt: a single prompt or a list of prompts; each one is passed through `prompt_clean`.
        num_videos_per_prompt: how many times each prompt's embedding is duplicated.
        max_sequence_length: tokenizer padding / truncation length, also the output sequence length.
        device: target device; falls back to `self._execution_device`.
        dtype: target dtype; falls back to `self.text_encoder.dtype`.

    Returns:
        Tensor viewed as `(batch_size * num_videos_per_prompt, seq_len, -1)`.
    """
    if not device:
        device = self._execution_device
    if not dtype:
        dtype = self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    cleaned_prompts = [prompt_clean(entry) for entry in prompt]
    batch_size = len(cleaned_prompts)

    tokenized = self.tokenizer(
        cleaned_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
    # number of attended (non-padding) positions per prompt
    valid_lengths = attention_mask.gt(0).sum(dim=1).long()

    encoder_output = self.text_encoder(input_ids.to(device), attention_mask.to(device))
    hidden_states = encoder_output.last_hidden_state.to(dtype=dtype, device=device)

    # cut every sequence to its attended length, then pad back with zeros to the fixed length
    padded = []
    for sequence, valid_length in zip(hidden_states, valid_lengths):
        kept = sequence[:valid_length]
        filler = kept.new_zeros(max_sequence_length - kept.size(0), kept.size(1))
        padded.append(torch.cat([kept, filler]))
    prompt_embeds = torch.stack(padded, dim=0)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    seq_len = prompt_embeds.shape[1]
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. device: (`torch.device`, *optional*): torch device dtype: (`torch.dtype`, *optional*): torch dtype """ device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Raises:
        ValueError: if `height`/`width` are not multiples of 16, if a requested callback tensor name is not
            listed in `self._callback_tensor_inputs`, if a prompt and its pre-computed embeddings are both
            given, if neither `prompt` nor `prompt_embeds` is given, or if a prompt is not a `str`/`list`.
    """
    if not (height % 16 == 0 and width % 16 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    # Every check below raises, so independent guards are evaluated in the same order as a chained test.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 480,
    width: int = 832,
    num_frames: int = 81,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    use_cfg_zero_star: Optional[bool] = False,
    use_zero_init: Optional[bool] = True,
    zero_steps: Optional[int] = 0,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the generation. Only encoded when classifier-free guidance is
            active (`guidance_scale > 1`) and `negative_prompt_embeds` is not given.
        height (`int`, defaults to `480`):
            The height in pixels of the generated video. Must be divisible by 16.
        width (`int`, defaults to `832`):
            The width in pixels of the generated video. Must be divisible by 16.
        num_frames (`int`, defaults to `81`):
            The number of frames in the generated video.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality result at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `5.0`):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). Guidance is enabled by setting `guidance_scale > 1`.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents. If not provided, a latents tensor is generated by sampling using the
            supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings are generated from the `prompt`
            input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from the
            `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format passed to the video processor. `"latent"` skips VAE decoding and returns the
            latents directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is forwarded to the transformer on every call.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step with the following arguments: `callback_on_step_end(self: DiffusionPipeline,
            step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include all tensors named
            in `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. You will only be able to include
            variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to `512`):
            Maximum number of tokens used when encoding the prompts.
        use_cfg_zero_star (`bool`, *optional*, defaults to `False`):
            When guidance is active, scale the unconditional prediction by the per-sample factor returned by
            `optimized_scale` before applying guidance.
        use_zero_init (`bool`, *optional*, defaults to `True`):
            Only used together with `use_cfg_zero_star`: replace the prediction by zeros for every step whose
            index is `<= zero_steps`.
        zero_steps (`int`, *optional*, defaults to `0`):
            Last step index (inclusive) that is zeroed when `use_zero_init` is set.

    Examples:

    Returns:
        [`~WanPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a one-element `tuple`
            holding the generated video (or the latents when `output_type="latent"`).
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]

            if self.do_classifier_free_guidance:
                noise_pred_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=negative_prompt_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred_text = noise_pred
                if use_cfg_zero_star:
                    # The predictions hold `batch_size * num_videos_per_prompt` samples, so the per-sample
                    # scale must be computed over that effective batch, not over the number of prompts.
                    effective_batch = noise_pred_text.shape[0]
                    positive_flat = noise_pred_text.reshape(effective_batch, -1)
                    negative_flat = noise_pred_uncond.reshape(effective_batch, -1)

                    alpha = optimized_scale(positive_flat, negative_flat)
                    # one scalar per sample, broadcastable over all remaining dimensions
                    alpha = alpha.reshape(effective_batch, *([1] * (noise_pred_text.ndim - 1)))
                    alpha = alpha.to(noise_pred_text.dtype)

                    if (i <= zero_steps) and use_zero_init:
                        noise_pred = noise_pred_text*0.
                    else:
                        noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha)
                else:
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if callback_on_step_end is not None:
                # requested tensors are looked up by name among this function's local variables
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # advance the progress bar once per completed scheduler step
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if not output_type == "latent":
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # note: this is the reciprocal of the configured std, hence the division below
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return WanPipelineOutput(frames=video)
def get_civitai(
    model_id:int=None, # if model_id is provided assume fetch-from-cache
    query:str = '', # search query or tag is required
    tag:str = '', # search query or tag is required
    types:str = '', # Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses
    sort:str = '', # Highest Rated, Most Downloaded, Newest
    period:str = '', # AllTime, Year, Month, Week, Day
    nsfw:bool = None, # optional:bool
    limit:int = 0,
    base:str = '',
    token:str = None,
    exact:bool = True,
):
    """
    GET handler: return CivitAI models as a JSON list.

    With `model_id` the answer is built from `search_civitai.models`; otherwise a search is run when
    `query` or `tag` is non-empty. Without any of the three an empty list is returned.
    """
    from modules.civitai import search_civitai
    if model_id is not None:
        cached = models_to_json(search_civitai.models, model_id=model_id)
        return JSONResponse(content=cached, status_code=200)
    if len(query) == 0 and len(tag) == 0:
        # nothing to look up
        return JSONResponse(content=[], status_code=200)
    found = search_civitai.search_civitai(
        query=query,
        tag=tag,
        types=types,
        sort=sort,
        period=period,
        nsfw=nsfw,
        limit=limit,
        base=base,
        token=token,
        exact=exact
    )
    return JSONResponse(content=models_to_json(found), status_code=200)


def post_civitai(page:str=None):
    """POST handler: run the metadata scan for `page` to completion and return its final yielded value."""
    from modules.civitai import metadata_civitai
    latest = []
    progress = metadata_civitai.civit_search_metadata(title=page, raw=True)
    for latest in progress: # drain the generator, keeping only the last yielded result
        pass
    return latest


def register_api():
    """Register the GET and POST handlers on the `/sdapi/v1/civitai` route."""
    from modules.shared import api
    for handler, verb in ((get_civitai, "GET"), (post_civitai, "POST")):
        api.add_api_route("/sdapi/v1/civitai", handler, methods=[verb], response_model=list)
def download_civit_preview(model_path: str, preview_url: str):
    """
    Download a preview image or video for a model and store it next to the model file.

    The preview is saved as `<model_path without extension><extension of preview_url>`. For `.mp4`
    previews a thumbnail frame is extracted via `save_video_frame`; other files are opened with PIL to
    confirm they are readable images.

    Calling with `model_path=None` only resets the shared progress bar.

    Returns:
        Tuple `(code, size, note)`: `200` on success, `304` if the preview file already exists, `400` if
        fewer than 1024 bytes were received, `500` on any other failure.
    """
    global pbar # pylint: disable=global-statement
    if model_path is None:
        pbar = None
        return 500, '', ''
    ext = os.path.splitext(preview_url)[1]
    preview_file = os.path.splitext(model_path)[0] + ext
    is_video = preview_file.lower().endswith('.mp4')
    is_json = preview_file.lower().endswith('.json')
    if is_json:
        shared.log.warning(f'CivitAI download: url="{preview_url}" skip json')
        return 500, '', 'exepected preview image got json'
    if os.path.exists(preview_file):
        return 304, '', 'already exists'
    r = shared.req(preview_url, stream=True)
    total_size = int(r.headers.get('content-length', 0))
    block_size = 16384 # 16KB blocks
    written = 0
    img = None
    jobid = shared.state.begin('Download CivitAI')
    if pbar is None:
        pbar = p.Progress(p.TextColumn('[cyan]Download'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), p.TextColumn('[yellow]{task.description}'), console=shared.console)
    try:
        with open(preview_file, 'wb') as f:
            with pbar:
                task = pbar.add_task(description=preview_file, total=total_size)
                for data in r.iter_content(block_size):
                    written = written + len(data)
                    f.write(data)
                    # advance by what was actually received; the last chunk is usually shorter than block_size
                    pbar.update(task, advance=len(data))
        if written < 1024: # min threshold
            os.remove(preview_file)
            return 400, '', 'removed invalid download'
        if is_video:
            img = save_video_frame(preview_file)
        else:
            img = Image.open(preview_file)
    except Exception as e:
        shared.log.error(f'CivitAI download error: url={preview_url} file="{preview_file}" written={written} {e}')
        # drop the partial/unreadable file, otherwise every later call would report it as 'already exists'
        try:
            if os.path.exists(preview_file):
                os.remove(preview_file)
        except OSError as cleanup_error:
            shared.log.debug(f'CivitAI download cleanup: file="{preview_file}" {cleanup_error}')
        return 500, '', str(e)
    finally:
        # always close the job, including on the early 400 return above
        shared.state.end(jobid)
    if img is None:
        return 500, '', 'image is none'
    shared.log.info(f'CivitAI download: url={preview_url} file="{preview_file}" size={total_size} image={img.size}')
    img.close()
    return 200, str(total_size), '' # code/size/note
int(r.headers.get('content-length', 0)) if model_name is None or len(model_name) == 0: cn = r.headers.get('content-disposition', '') model_name = cn.split('filename=')[-1].strip('"') model_path = model_path.strip() if len(model_path) > 0: if os.path.isabs(model_path): pass else: model_path = os.path.join(paths.models_path, model_path) elif model_type.lower() == 'lora': model_path = shared.opts.lora_dir elif model_type.lower() == 'embedding': model_path = shared.opts.embeddings_dir elif model_type.lower() == 'vae': model_path = shared.opts.vae_dir else: model_path = shared.opts.ckpt_dir model_file = os.path.join(model_path, model_name) temp_file = os.path.join(model_path, temp_file) res = f'Model download: name="{model_name}" url="{model_url}" path="{model_path}" temp="{temp_file}"' if os.path.isfile(model_file): res += ' already exists' shared.log.warning(res) return res res += f' size={round((starting_pos + total_size)/1024/1024, 2)}Mb' shared.log.info(res) jobid = shared.state.begin('Download CivitAI') block_size = 16384 # 16KB blocks written = starting_pos global pbar # pylint: disable=global-statement if pbar is None: pbar = p.Progress(p.TextColumn('[cyan]{task.description}'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), p.TextColumn('[cyan]{task.fields[name]}'), console=shared.console) with pbar: task = pbar.add_task(description="Download starting", total=starting_pos+total_size, name=model_name) try: with open(temp_file, 'ab') as f: for data in r.iter_content(block_size): if written == 0: try: # check if response is JSON message instead of bytes shared.log.error(f'Model download: response={json.loads(data.decode("utf-8"))}') raise ValueError('response: type=json expected=bytes') except Exception: # this is good pass written = written + len(data) f.write(data) pbar.update(task, description="Download", completed=written) if written < 1024: # min threshold os.remove(temp_file) 
class CivitModel:
    """
    Record describing one local model and what is known about it on CivitAI.

    Args:
        name: display name of the model; also stored as `file`.
        fn: filename of the model on disk.
        sha: hash of the model file, if known.
        meta: metadata dictionary; its `id` entry (default `0`) becomes the model id.
            `None` is treated as an empty dictionary.
    """

    def __init__(self, name, fn, sha = None, meta = None):
        # a fresh dict per instance: a `{}` default would be one object shared by every model
        meta = {} if meta is None else meta
        self.name = name
        self.file = name
        self.id = meta.get('id', 0)
        self.fn = fn
        self.sha = sha
        self.meta = meta
        self.versions = 0
        self.vername = ''
        self.latest = ''
        self.latest_hashes = []
        self.latest_name = ''
        self.url = None
        self.status = 'Not found'
FileIDNameHashVersionsLatestStatus
""" tbody = '' for row in rows: try: tbody += f""" {row.file} {row.id} {row.name} {row.sha} {row.versions} {row.latest} {row.status} """ except Exception as e: log.error(f'Model list: row={row} {e}') return html.format(tbody=tbody) log.debug('CivitAI update metadata: models') from modules import ui_extra_networks from modules.civitai.download_civitai import download_civit_meta pages = ui_extra_networks.get_pages('Model') if len(pages) == 0: return 'CivitAI update metadata: no models found' page: ui_extra_networks.ExtraNetworksPage = pages[0] results = [] all_hashes = [(item.get('hash', None) or 'XXXXXXXX').upper()[:8] for item in page.list_items()] for item in page.list_items(): model = CivitModel(name=item['name'], fn=item['filename'], sha=item.get('hash', None), meta=item.get('metadata', {})) if model.sha is None or len(model.sha) == 0: log.debug(f'CivitAI skip search: name="{model.name}" hash=None') else: r = req(f'https://civitai.com/api/v1/model-versions/by-hash/{model.sha}') log.debug(f'CivitAI search: name="{model.name}" hash={model.sha} status={r.status_code}') if r.status_code == 200: d = r.json() model.id = d['modelId'] download_civit_meta(model.fn, model.id) fn = os.path.splitext(item['filename'])[0] + '.json' model.meta = readfile(fn, silent=True, as_type="dict") model.name = model.meta.get('name', model.name) model.versions = len(model.meta.get('modelVersions', [])) versions = model.meta.get('modelVersions', []) if len(versions) > 0: model.latest = versions[0].get('name', '') model.latest_hashes.clear() for v in versions[0].get('files', []): for h in v.get('hashes', {}).values(): model.latest_hashes.append(h[:8].upper()) for ver in versions: for f in ver.get('files', []): for h in f.get('hashes', {}).values(): if h[:8].upper() == model.sha[:8].upper(): model.vername = ver.get('name', '') model.url = f.get('downloadUrl', None) model.latest_name = f.get('name', '') if model.vername == model.latest: model.status = 'Latest version' elif any(map(lambda v: v 
in model.latest_hashes, all_hashes)): # pylint: disable=cell-var-from-loop # noqa: C417 model.status = 'Update downloaded' else: model.status = 'Update available' break results.append(model) yield results if raw else create_update_metadata_table(results) yield results if raw else create_update_metadata_table(results) def civit_search_model(name, tag, model_type): # types = 'LORA' if model_type == 'LoRA' else 'Checkpoint' url = 'https://civitai.com/api/v1/models?limit=25&Sort=Newest' if model_type == 'Model': url += '&types=Checkpoint' elif model_type == 'LoRA': url += '&types=LORA&types=DoRA&types=LoCon' elif model_type == 'Embedding': url += '&types=TextualInversion' elif model_type == 'VAE': url += '&types=VAE' if name is not None and len(name) > 0: url += f'&query={name}' if tag is not None and len(tag) > 0: url += f'&tag={tag}' r = req(url) log.debug(f'CivitAI search: type={model_type} name="{name}" tag={tag or "none"} url="{url}" status={r.status_code}') if r.status_code != 200: log.warning(f'CivitAI search: name="{name}" tag={tag} status={r.status_code}') return [], gr.update(visible=False, value=[]), gr.update(visible=False, value=None), gr.update(visible=False, value=None) try: body = r.json() except Exception as e: log.error(f'CivitAI search: name="{name}" tag={tag} {e}') return [], gr.update(visible=False, value=[]), gr.update(visible=False, value=None), gr.update(visible=False, value=None) global data # pylint: disable=global-statement data = body.get('items', []) data1 = [] for model in data: found = 0 if model_type == 'LoRA' and model['type'].lower() in ['lora', 'locon', 'dora', 'lycoris']: found += 1 elif model_type == 'Embedding' and model['type'].lower() in ['textualinversion', 'embedding']: found += 1 elif model_type == 'Model' and model['type'].lower() in ['checkpoint']: found += 1 elif model_type == 'VAE' and model['type'].lower() in ['vae']: found += 1 elif model_type == 'Other': found += 1 if found > 0: data1.append([ model['id'], 
def _civit_lookup_by_hash(item, sha, result, results):
    """Query CivitAI for `sha`, download metadata and the first working preview; return True if a preview was saved."""
    from modules.civitai.download_civitai import download_civit_preview, download_civit_meta
    r = req(f'https://civitai.com/api/v1/model-versions/by-hash/{sha}')
    log.debug(f'CivitAI search: name="{item["name"]}" hash={sha} status={r.status_code}')
    result['hash'] = sha
    result['code'] = r.status_code
    if r.status_code != 200:
        return False
    d = r.json()
    result['code'], result['size'], result['note'] = download_civit_meta(item['filename'], d['modelId'])
    result['id'] = d['modelId']
    result['type'] = 'metadata'
    # append a snapshot: `result` keeps being updated below and must not rewrite rows already reported
    results.append(result.copy())
    for image in d.get('images') or []:
        result['code'], result['size'], result['note'] = download_civit_preview(item['filename'], image['url'])
        if result['code'] == 200:
            result['type'] = 'preview'
            results.append(result.copy())
            return True
    return False


def atomic_civit_search_metadata(item, results):
    """
    Look up one extra-network item on CivitAI and fetch its missing metadata/preview.

    Nothing is done unless the model file exists and either its preview is the `missing.png`
    placeholder or its `.json` metadata file is absent/empty. The lookup first uses the item's stored
    hash; if that yields no preview and the file is smaller than 1 GiB, the sha256 of the file is
    computed and its first 10 characters are tried as well.

    Args:
        item: item dictionary; `filename`, `name` and `preview` are read, `hash` is optional.
        results: list shared between workers; one row is appended per completed stage, and a final row
            is appended when no preview could be downloaded.
    """
    if item is None:
        return
    try:
        meta = os.path.splitext(item['filename'])[0] + '.json'
    except Exception:
        # items without a usable filename are skipped silently
        return
    has_meta = os.path.isfile(meta) and os.stat(meta).st_size > 0
    if not (('missing.png' in item['preview'] or not has_meta) and os.path.isfile(item['filename'])):
        return
    sha = item.get('hash', None)
    found = False
    result = {
        'id': '',
        'name': item['name'],
        'type': '',
        'hash': '',
        'code': '',
        'size': '',
        'note': '',
    }
    if sha is not None and len(sha) > 0:
        found = _civit_lookup_by_hash(item, sha, result, results)
    if not found and os.stat(item['filename']).st_size < (1024 * 1024 * 1024):
        from modules import hashes
        file_sha = hashes.calculate_sha256(item['filename'], quiet=True)
        if file_sha: # NOTE(review): guards against a failed hash computation returning nothing
            found = _civit_lookup_by_hash(item, file_sha[:10], result, results)
    if not found:
        results.append(result)
NameIDTypeCodeHashSizeNote
""" tbody = '' for row in rows: try: tbody += f""" {row['name']} {row['id']} {row['type']} {row['code']} {row['hash']} {row['size']} {row['note']} """ except Exception as e: log.error(f'Model list: row={row} {e}') return html.format(tbody=tbody) from modules.ui_extra_networks import get_pages results = [] scanned, skipped = 0, 0 t0 = time.time() candidates = [] re_skip = [r.strip() for r in opts.extra_networks_scan_skip.split(',') if len(r.strip()) > 0] for page in get_pages(): if type(title) == str: if page.title.lower() != title.lower(): continue if page.name == 'style' or page.name == 'wildcards': continue for item in page.list_items(): if item is None: continue if any(re.search(re_str, item.get('name', '') + item.get('filename', '')) for re_str in re_skip): skipped += 1 continue scanned += 1 candidates.append(item) log.debug(f'CivitAI search metadata: type={title if type(title) == str else "all"} workers={max_workers} skip={len(re_skip)} items={len(candidates)}') import concurrent with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_items = {} for fn in candidates: future_items[executor.submit(atomic_civit_search_metadata, fn, results)] = fn for future in concurrent.futures.as_completed(future_items): future.result() yield results if raw else create_search_metadata_table(results) t1 = time.time() log.debug(f'CivitAI search metadata: scanned={scanned} skipped={skipped} time={t1-t0:.2f}') yield results if raw else create_search_metadata_table(results) ================================================ FILE: modules/civitai/search_civitai.py ================================================ from dataclasses import dataclass import os import json import time from installer import install, log full_dct = False full_html = False base_models = ['', 'AuraFlow', 'Chroma', 'CogVideoX', 'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D', 'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 
@dataclass(init=False)
class ModelFile():
    """A single downloadable file belonging to a CivitAI model version.

    Fields are declared at class level so the dataclass-generated `__eq__` and
    `__repr__` actually cover them; with no declared fields every instance
    compared equal to every other instance.
    """
    id: int
    size: int  # bytes, derived from the 'sizeKB' entry
    name: str
    type: str
    hashes: list[str]
    url: str
    dct: dict  # raw source dict, kept only when the module-level `full_dct` flag is set

    def __init__(self, dct: dict):
        """Build from a CivitAI file dict, or from its JSON string form."""
        if isinstance(dct, str):
            dct = json.loads(dct)
        self.id = dct.get('id', 0)
        # `or` guards against explicit nulls in the payload, not just missing keys
        self.size = int(1024 * (dct.get('sizeKB') or 0))
        self.name = dct.get('name', 'Unknown')
        self.type = dct.get('type', 'Unknown')
        self.hashes = [str(h) for h in (dct.get('hashes') or {}).values()]
        self.url = dct.get('downloadUrl', '')
        self.dct = dct if full_dct else {}

    def __str__(self):
        return f'ModelFile(id={self.id} name="{self.name}" size={self.size} type="{self.type}" url="{self.url}")'
def search_civitai(
        query:str,
        tag:str = '', # optional:tag name
        types:str = '', # (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses)
        sort:str = '', # (Highest Rated, Most Downloaded, Newest)
        period:str = '', # (AllTime, Year, Month, Week, Day)
        nsfw:bool = None, # optional:bool
        limit:int = 0,
        base:str = '', # list
        token:str = None,
        exact:bool = True,
):
    """Query the CivitAI models API and cache the parsed results.

    A purely numeric `query` is treated as a model id and fetched directly;
    anything else is sent as a search with the optional filters. When `exact`
    is set, results are narrowed to models whose model, version or file name
    contains the query (case-insensitive); if nothing matches, all results
    are returned.

    Returns:
        list[Model]: parsed models, also stored in the module-level `models`
        cache. An empty list is returned for an empty query, a non-200
        response, a failed request or an unparsable body; the cache is left
        untouched in those cases.
    """
    global models # pylint: disable=global-statement
    import requests
    from urllib.parse import urlencode
    install('beautifulsoup4')

    if len(query) == 0:
        log.error('CivitAI: empty query')
        return []

    t0 = time.time()
    dct = { 'query': query }
    if len(tag) > 0:
        dct['tag'] = tag
    if nsfw is not None:
        dct['nsfw'] = 'true' if nsfw else 'false'
    if limit > 0:
        dct['limit'] = limit
    if len(types) > 0:
        dct['types'] = types
    if len(sort) > 0:
        dct['sort'] = sort
    if len(period) > 0:
        dct['period'] = period
    if len(base) > 0:
        dct['baseModels'] = base
    encoded = urlencode(dct)

    headers = {}
    if token is None:
        token = os.environ.get('CIVITAI_TOKEN', None)
    if token is not None and len(token) > 0:
        headers['Authorization'] = f'Bearer {token}'

    url = 'https://civitai.com/api/v1/models'
    if query.isnumeric():
        uri = f'{url}/{query}'
    else:
        uri = f'{url}?{encoded}'

    log.info(f'CivitAI request: uri="{uri}" dct={dct} token={token is not None}')
    try:
        result = requests.get(uri, headers=headers, timeout=60)
    except requests.RequestException as e:
        # treat transport failures like any other failed request
        log.error(f'CivitAI: uri={uri} {e}')
        return []
    if result.status_code != 200:
        log.error(f'CivitAI: code={result.status_code} reason={result.reason} uri={result.url}')
        return []

    try:
        dct = result.json()
    except ValueError as e:
        log.error(f'CivitAI: code={result.status_code} uri={result.url} {e}')
        return []
    if 'items' not in dct:
        items = [dct] # single model
    else:
        items = dct.get('items') or []
    all_models: list[Model] = [Model(item) for item in items]

    exact_models: list[Model] = []
    if exact:
        needle = query.lower()
        for model in all_models:
            model_names = [model.name.lower()]
            version_names = [v.name.lower() for v in model.versions]
            file_names = [f.name.lower() for v in model.versions for f in v.files]
            if any(needle in name for name in model_names + version_names + file_names):
                exact_models.append(model)

    t1 = time.time()
    # report the size of this response, not of the previous search still held in `models`
    log.info(f'CivitAI result: code={result.status_code} exact={len(exact_models)} total={len(all_models)} time={t1-t0:.2f}')
    models = exact_models if len(exact_models) > 0 else all_models
    return models
""" cards = """
{cards}
""" card = """
{name}
{type}
{name}
""" all_cards = '' for model in all_models: previews = [] for version in model.versions: for image in version.images: if image.url and len(image.url) > 0 and not image.url.lower().endswith('.mp4'): previews.append(image.url) if len(previews) == 0: previews = ['/sdapi/v1/network/thumb?filename=html/missing.png'] all_cards += card.format(id=model.id, name=model.name, type=model.type, preview=previews[0]) html = details + cards.format(cards=all_cards) return html def print_models(all_models: list[Model]): for model in all_models: log.info(f' {model}') log.trace('Model', model.dct) for version in model.versions: log.info(f' {version}') log.trace('ModelVersion', version.dct) for file in version.files: log.info(f' {file}') log.trace('ModelFile', file.dct) for image in version.images: log.info(f' {image}') log.trace('ModelImage', image.dct) ================================================ FILE: modules/cmd_args.py ================================================ import os import sys import argparse from modules.paths import data_path, models_path parsed = None parser = argparse.ArgumentParser(description="SD.Next", conflict_handler='resolve', epilog='For other options see UI Settings page', prog='', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200)) parser._optionals = parser.add_argument_group('Other options') # pylint: disable=protected-access def parse_args(): global parsed # pylint: disable=global-statement if parsed is None: parsed, _ = parser.parse_known_args() return parsed def main_args(): # main server args group_config = parser.add_argument_group('Configuration') group_config.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s") group_config.add_argument("--ui-config", type=str, default=os.environ.get("SD_UICONFIG", os.path.join(data_path, 'ui-config.json')), 
help="Use specific UI configuration file, default: %(default)s") group_config.add_argument("--freeze", default=os.environ.get("SD_FREEZE", False), action='store_true', help="Disable editing settings") group_config.add_argument("--medvram", default=os.environ.get("SD_MEDVRAM", False), action='store_true', help="Split model stages and keep only active part in VRAM, default: %(default)s") group_config.add_argument("--lowvram", default=os.environ.get("SD_LOWVRAM", False), action='store_true', help="Split model components and keep only active part in VRAM, default: %(default)s") group_config.add_argument("--disable", default=os.environ.get("SD_DISABLE", ''), help="Disable specific UI tabs: %(default)s") def compatibility_args(): # removed args are added here as hidden in fixed format for compatbility reasons group_compat = parser.add_argument_group('Compatibility options') group_compat.add_argument('--backend', type=str, choices=['diffusers', 'original'], help=argparse.SUPPRESS) group_compat.add_argument('--hypernetwork-dir', default=os.path.join(models_path, 'hypernetworks'), help=argparse.SUPPRESS) group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS) group_compat.add_argument("--enable_insecure_extension_access", default=os.environ.get("SD_INSECURE", False), action='store_true', help=argparse.SUPPRESS) group_compat.add_argument("--use-cpu", nargs='+', default=[], type=str.lower, help=argparse.SUPPRESS) group_compat.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui group_compat.add_argument('--vae', type=str, default=os.environ.get("SD_VAE", None), help=argparse.SUPPRESS) group_compat.add_argument("--ui-settings-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'config.json')) group_compat.add_argument("--ui-config-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 
def settings_args(opts, args):
    """Register hidden compatibility arguments that depend on `opts`, then parse.

    Adds legacy command-line flags (hidden from help) whose defaults come from
    the current settings object, pins a set of removed settings on `opts` to
    fixed values, and returns the namespace from a full `parser.parse_args`.

    Args:
        opts: settings object; read for argument defaults and mutated with
            fixed legacy values.
        args: previously parsed namespace; only captured by the `lora_dir`
            change callback registered here.

    Returns:
        The namespace produced by parsing the command line with all arguments
        registered.
    """
    # removed args are added here as hidden in fixed format for compatibility reasons
    group_compat = parser.add_argument_group('Compatibility options')
    group_compat.add_argument("--allow-code", default=os.environ.get("SD_ALLOWCODE", False), action='store_true', help=argparse.SUPPRESS)
    group_compat.add_argument("--use-cpu", nargs='+', default=[], type=str.lower, help=argparse.SUPPRESS)
    group_compat.add_argument("-f", action='store_true', help=argparse.SUPPRESS)  # allows running as root; implemented outside of webui
    group_compat.add_argument('--vae', type=str, default=os.environ.get("SD_VAE", None), help=argparse.SUPPRESS)
    group_compat.add_argument("--ui-settings-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'config.json'))
    group_compat.add_argument("--ui-config-file", type=str, help=argparse.SUPPRESS, default=os.path.join(data_path, 'ui-config.json'))
    group_compat.add_argument("--hide-ui-dir-config", action='store_true', help=argparse.SUPPRESS, default=False)
    group_compat.add_argument("--disable-console-progressbars", action='store_true', help=argparse.SUPPRESS, default=True)
    group_compat.add_argument("--disable-safe-unpickle", action='store_true', help=argparse.SUPPRESS, default=True)
    group_compat.add_argument("--lowram", action='store_true', help=argparse.SUPPRESS)
    group_compat.add_argument("--disable-extension-access", default=False, action='store_true', help=argparse.SUPPRESS)
    group_compat.add_argument("--allowed-paths", nargs='+', default=[], type=str, required=False, help="add additional paths to paths allowed for web access")
    group_compat.add_argument("--api", action='store_true', help=argparse.SUPPRESS, default=True)
    group_compat.add_argument("--api-auth", type=str, help=argparse.SUPPRESS, default=None)

    # removed args that have been moved to opts are added here as hidden with default values as defined in opts
    group_compat.add_argument("--ckpt-dir", type=str, help=argparse.SUPPRESS, default=opts.ckpt_dir)
    group_compat.add_argument("--vae-dir", type=str, help=argparse.SUPPRESS, default=opts.vae_dir)
    # registered once: a second, untyped "--embeddings-dir" used to replace this one via conflict_handler='resolve'
    group_compat.add_argument("--embeddings-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_dir)
    group_compat.add_argument("--embeddings-templates-dir", type=str, help=argparse.SUPPRESS, default=opts.embeddings_templates_dir)
    group_compat.add_argument("--codeformer-models-path", type=str, help=argparse.SUPPRESS, default=opts.codeformer_models_path)
    group_compat.add_argument("--gfpgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.gfpgan_models_path)
    group_compat.add_argument("--esrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.esrgan_models_path)
    group_compat.add_argument("--bsrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.bsrgan_models_path)
    group_compat.add_argument("--realesrgan-models-path", type=str, help=argparse.SUPPRESS, default=opts.realesrgan_models_path)
    group_compat.add_argument("--scunet-models-path", help=argparse.SUPPRESS, default=opts.scunet_models_path)
    group_compat.add_argument("--swinir-models-path", help=argparse.SUPPRESS, default=opts.swinir_models_path)
    group_compat.add_argument("--ldsr-models-path", help=argparse.SUPPRESS, default=opts.ldsr_models_path)
    group_compat.add_argument("--clip-models-path", type=str, help=argparse.SUPPRESS, default=opts.clip_models_path)
    group_compat.add_argument("--opt-channelslast", help=argparse.SUPPRESS, action='store_true', default=opts.opt_channelslast)
    group_compat.add_argument("--xformers", default=(opts.cross_attention_optimization == "xFormers"), action='store_true', help=argparse.SUPPRESS)
    group_compat.add_argument("--disable-nan-check", help=argparse.SUPPRESS, action='store_true', default=opts.disable_nan_check)
    group_compat.add_argument("--rollback-vae", help=argparse.SUPPRESS, default=opts.rollback_vae)
    group_compat.add_argument("--no-half", help=argparse.SUPPRESS, action='store_true', default=opts.no_half)
    group_compat.add_argument("--no-half-vae", help=argparse.SUPPRESS, action='store_true', default=opts.no_half_vae)
    group_compat.add_argument("--precision", help=argparse.SUPPRESS, default=opts.precision)
    group_compat.add_argument("--sub-quad-q-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_q_chunk_size)
    group_compat.add_argument("--sub-quad-kv-chunk-size", help=argparse.SUPPRESS, default=opts.sub_quad_kv_chunk_size)
    group_compat.add_argument("--sub-quad-chunk-threshold", help=argparse.SUPPRESS, default=opts.sub_quad_chunk_threshold)
    group_compat.add_argument("--lora-dir", help=argparse.SUPPRESS, default=opts.lora_dir)
    group_compat.add_argument("--enable-console-prompts", help=argparse.SUPPRESS, action='store_true', default=False)
    group_compat.add_argument("--safe", help=argparse.SUPPRESS, action='store_true', default=False)
    group_compat.add_argument("--use-xformers", help=argparse.SUPPRESS, action='store_true', default=False)

    # removed opts are added here with fixed values for compatibility reasons
    opts.use_old_emphasis_implementation = False
    opts.use_old_karras_scheduler_sigmas = False
    opts.no_dpmpp_sde_batch_determinism = False
    opts.lora_apply_to_outputs = False
    opts.do_not_show_images = False
    opts.add_model_hash_to_info = True
    opts.add_model_name_to_info = True
    opts.js_modal_lightbox = True
    opts.js_modal_lightbox_initially_zoomed = True
    opts.show_progress_in_title = False
    opts.sd_vae_as_default = True
    opts.enable_emphasis = True
    opts.enable_batch_seeds = True
    # opts.multiple_tqdm = False
    opts.print_hypernet_extra = False
    opts.dimensions_and_batch_together = True
    opts.enable_pnginfo = True
    opts.data['clip_skip'] = 1

    opts.onchange("lora_dir", lambda: setattr(args, "lora_dir", opts.lora_dir))

    if "USED_VSCODE_COMMAND_PICKARGS" in os.environ:
        # re-split the joined command line so arguments arriving as one string are tokenized
        import shlex
        argv = shlex.split(" ".join(sys.argv[1:]))
        args = parser.parse_args(argv)
    else:
        args = parser.parse_args()
    return args
class DepthAnythingDetector:
    """Depth estimation processor wrapping a Depth-Anything model.

    https://github.com/LiheYoung/Depth-Anything
    """

    def __init__(self, model):
        """Store the model and build the preprocessing pipeline applied before inference."""
        from torchvision.transforms import Compose
        from modules.control.proc.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

        self.model = model
        self.transform = Compose([
            Resize(
                width=518,
                height=518,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method="lower_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet()])

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path: str, cache_dir: str, local_files_only=False) -> "DepthAnythingDetector":
        """Download `pytorch_model.bin` from the given hub repo and build a detector.

        Args:
            pretrained_model_or_path: hub repo id passed to `hf_hub_download`.
            cache_dir: download cache directory.
            local_files_only: forwarded to `hf_hub_download`.

        Returns:
            A detector wrapping the loaded model, placed on `devices.device` in eval mode.
        """
        from modules.control.proc.depth_anything.dpt import DPT_DINOv2
        import huggingface_hub as hf

        model = (
            DPT_DINOv2(
                encoder="vitl",
                features=256,
                out_channels=[256, 512, 1024, 1024],
                localhub=False,
            )
            .to(devices.device)
            .eval()
        )
        model_path = hf.hf_hub_download(repo_id=pretrained_model_or_path, filename="pytorch_model.bin", cache_dir=cache_dir, local_files_only=local_files_only)
        # load onto CPU first so a checkpoint saved from another device still opens;
        # load_state_dict copies the values into the already-placed model
        model_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(model_dict)
        return cls(model)

    def __call__(self, image, color_map: str = "none", output_type: str = 'pil'):
        """Estimate depth for one image.

        Args:
            image: PIL image or array with height and width as the first two axes.
            color_map: name looked up in `masking.COLORMAP`, or 'none' for a plain depth map.
            output_type: 'pil' to return a PIL image, anything else to return the array.

        Returns:
            Depth map scaled to 0..255 at the input resolution, as a PIL image or uint8 array.
        """
        self.model.to(devices.device)
        if isinstance(image, Image.Image):
            image = np.array(image)
        h, w = image.shape[:2]
        # NOTE(review): this swaps channel order; arrays coming from PIL are already RGB - confirm intent
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        image = self.transform({ "image": image })["image"]
        image = torch.from_numpy(image).unsqueeze(0).to(devices.device)
        with devices.inference_context():
            depth = self.model(image)
        if opts.control_move_processor:
            self.model.to('cpu')
        depth = F.interpolate(depth[None], (h, w), mode="bilinear", align_corners=False)[0, 0]
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth = (depth - depth_min) / depth_range * 255.0
        else:
            # constant prediction: avoid 0/0 producing NaN before the uint8 cast
            depth = torch.zeros_like(depth)
        depth = depth.cpu().numpy().astype(np.uint8)
        if color_map != 'none':
            depth = cv2.applyColorMap(depth, masking.COLORMAP.index(color_map))[:, :, ::-1]
        if output_type == "pil":
            depth = Image.fromarray(depth)
        return depth

    # def unload_model(self):
    #    self.model.to("cpu")
class ResidualConvUnit(nn.Module):
    """Residual convolution module: two 3x3 convolutions with a skip connection."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of input and output channels
            activation (nn.Module): activation applied before each convolution
            bn (bool): add a batch norm after each convolution when exactly True
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn is True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: the input plus the output of the two-convolution branch
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn is True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn is True:
            out = self.bn2(out)

        # the former `groups > 1` branch called a `conv_merge` layer that is never
        # created; `groups` is fixed to 1 in __init__, so it could not run
        return self.skip_add.add(out, x)
def _make_fusion_block(features, use_bn, size = None):
    """Build a FeatureFusionBlock with the fixed settings used by the DPT head.

    Only the channel count, batch-norm flag and optional output size vary;
    the activation is a non-inplace ReLU and corners are aligned on upsampling.
    """
    block_options = {
        "deconv": False,
        "bn": use_bn,
        "expand": False,
        "align_corners": True,
        "size": size,
    }
    relu = nn.ReLU(False)
    return FeatureFusionBlock(features, relu, **block_options)
* in_channels, in_channels), nn.GELU())) self.scratch = _make_scratch( out_channels, features, groups=1, expand=False, ) self.scratch.stem_transpose = None self.scratch.refinenet1 = _make_fusion_block(features, use_bn) self.scratch.refinenet2 = _make_fusion_block(features, use_bn) self.scratch.refinenet3 = _make_fusion_block(features, use_bn) self.scratch.refinenet4 = _make_fusion_block(features, use_bn) head_features_1 = features head_features_2 = 32 if nclass > 1: self.scratch.output_conv = nn.Sequential( nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0), ) else: self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) self.scratch.output_conv2 = nn.Sequential( nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True), nn.Identity(), ) def forward(self, out_features, patch_h, patch_w): out = [] for i, x in enumerate(out_features): if self.use_clstoken: x, cls_token = x[0], x[1] readout = cls_token.unsqueeze(1).expand_as(x) x = self.readout_projects[i](torch.cat((x, readout), -1)) else: x = x[0] x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) x = self.projects[i](x) x = self.resize_layers[i](x) out.append(x) layer_1, layer_2, layer_3, layer_4 = out # pylint: disable=unbalanced-tuple-unpacking layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) path_1 = self.scratch.refinenet1(path_2, 
layer_1_rn) out = self.scratch.output_conv1(path_1) out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) out = self.scratch.output_conv2(out) return out class DPT_DINOv2(nn.Module): def __init__(self, encoder='vitl', features=256, out_channels=None, use_bn=False, use_clstoken=False, localhub=True): if out_channels is None: out_channels = [256, 512, 1024, 1024] super(DPT_DINOv2, self).__init__() assert encoder in ['vits', 'vitb', 'vitl'] # in case the Internet connection is not stable, please load the DINOv2 locally if localhub: self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) # pylint: disable=consider-using-f-string else: self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder)) # pylint: disable=consider-using-f-string dim = self.pretrained.blocks[0].attn.qkv.in_features self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) def forward(self, x): h, w = x.shape[-2:] features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) patch_h, patch_w = h // 14, w // 14 depth = self.depth_head(features, patch_h, patch_w) depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) depth = F.relu(depth) return depth.squeeze(1) class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin): def __init__(self, config): super().__init__(**config) ================================================ FILE: modules/control/proc/depth_anything/util/transform.py ================================================ import random from PIL import Image, ImageOps, ImageFilter import torch from torchvision import transforms import torch.nn.functional as F import numpy as np import cv2 import math def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. 
class Resize(object):
    """Transform that resizes a sample dict to a target (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Store the resize configuration.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: resize the image and every target present in the sample
                ("disparity", "depth", "semseg_mask", "mask").
                False: resize the image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: use one common scale for both axes, chosen according to
                'resize_method'; the output may then differ from the
                requested width and height.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to multiples of this
                value. Defaults to 1.
            resize_method (str, optional):
                "lower_bound": output is at least as large as the given size.
                "upper_bound": output is at most as large as the given size.
                "minimal": scale as little as possible.
                Defaults to "lower_bound".
            image_interpolation_method: OpenCV interpolation flag used for the
                image itself. Defaults to ``cv2.INTER_AREA``.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap ``x`` to a multiple of ``ensure_multiple_of``.

        The value is rounded to the nearest multiple first; if that exceeds
        ``max_val`` it is rounded down instead, and if the result is below
        ``min_val`` it is rounded up.
        """
        step = self.__multiple_of
        ratio = x / step

        snapped = (np.round(ratio) * step).astype(int)
        if max_val is not None and snapped > max_val:
            snapped = (np.floor(ratio) * step).astype(int)
        if snapped < min_val:
            snapped = (np.ceil(ratio) * step).astype(int)
        return snapped

    def get_size(self, width, height):
        """Compute the output ``(new_width, new_height)`` for an input of ``width`` x ``height``.

        Raises:
            ValueError: if ``resize_method`` is not one of "lower_bound",
                "upper_bound" or "minimal".
        """
        method = self.__resize_method
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            # Collapse both axes onto a single scale so the aspect ratio survives.
            if method == "lower_bound":
                # the larger scale guarantees both sides reach the target
                common = scale_width if scale_width > scale_height else scale_height
            elif method == "upper_bound":
                # the smaller scale guarantees neither side exceeds the target
                common = scale_width if scale_width < scale_height else scale_height
            elif method == "minimal":
                # whichever scale is closer to 1 changes the input the least
                width_is_closer = abs(1 - scale_width) < abs(1 - scale_height)
                common = scale_width if width_is_closer else scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )
            scale_height = scale_width = common

        # Each method bounds the snapped size from a different side.
        if method == "lower_bound":
            height_limit = {"min_val": self.__height}
            width_limit = {"min_val": self.__width}
        elif method == "upper_bound":
            height_limit = {"max_val": self.__height}
            width_limit = {"max_val": self.__width}
        elif method == "minimal":
            height_limit = {}
            width_limit = {}
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        new_height = self.constrain_to_multiple_of(scale_height * height, **height_limit)
        new_width = self.constrain_to_multiple_of(scale_width * width, **width_limit)
        return (new_width, new_height)

    def __call__(self, sample):
        """Resize ``sample["image"]`` (and, if enabled, its targets) in place and return the sample."""
        src_height, src_width = sample["image"].shape[:2]
        width, height = self.get_size(src_width, src_height)
        target = (width, height)

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            target,
            interpolation=self.__image_interpolation_method,
        )
        if not self.__resize_target:
            return sample

        # Targets use nearest-neighbour sampling so no new values are interpolated.
        for key in ("disparity", "depth"):
            if key in sample:
                sample[key] = cv2.resize(sample[key], target, interpolation=cv2.INTER_NEAREST)

        if "semseg_mask" in sample:
            # Resized through torch rather than OpenCV; note the (height, width) order here.
            mask_tensor = torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...]
            sample["semseg_mask"] = F.interpolate(mask_tensor, (height, width), mode='nearest').numpy()[0, 0]

        if "mask" in sample:
            # The mask is left as float32 after resizing.
            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                target,
                interpolation=cv2.INTER_NEAREST,
            )

        return sample
""" def __init__(self): pass def __call__(self, sample): image = np.transpose(sample["image"], (2, 0, 1)) sample["image"] = np.ascontiguousarray(image).astype(np.float32) if "mask" in sample: sample["mask"] = sample["mask"].astype(np.float32) sample["mask"] = np.ascontiguousarray(sample["mask"]) if "depth" in sample: depth = sample["depth"].astype(np.float32) sample["depth"] = np.ascontiguousarray(depth) if "semseg_mask" in sample: sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32) sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"]) return sample ================================================ FILE: modules/control/proc/depth_pro/__init__.py ================================================ import cv2 import torch import torch.nn.functional as F import numpy as np from PIL import Image from modules import devices, masking from modules.shared import opts class DepthProDetector: """Apple DepthPro detector (aligned with Depth Anything style).""" def __init__(self, model, processor): self.model = model self.processor = processor @classmethod def from_pretrained(cls, pretrained_model_or_path: str = "apple/DepthPro-hf", cache_dir: str = None, local_files_only = False) -> "DepthProDetector": from transformers import AutoImageProcessor, DepthProForDepthEstimation processor = AutoImageProcessor.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only) model = DepthProForDepthEstimation.from_pretrained( pretrained_model_or_path, cache_dir=cache_dir, local_files_only=local_files_only, ).to(devices.device).eval() return cls(model, processor) def __call__(self, image, color_map: str = "none", output_type: str = "pil"): self.model.to(devices.device) if isinstance(image, Image.Image): image = np.array(image) h, w = image.shape[:2] image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image_rgb) inputs = self.processor(images=pil_image, return_tensors="pt") inputs = {k: 
v.to(devices.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} with devices.inference_context(): outputs = self.model(**inputs) results = self.processor.post_process_depth_estimation(outputs, target_sizes=[(h, w)]) depth_tensor = results[0]["predicted_depth"].to(devices.device, dtype=torch.float32) if opts.control_move_processor: self.model.to("cpu") depth_tensor = F.interpolate(depth_tensor[None, None], size=(h, w), mode="bilinear", align_corners=False)[0, 0] depth_tensor = 1.0 / torch.clamp(depth_tensor, min=1e-6) depth_tensor -= depth_tensor.min() depth_max = depth_tensor.max() if depth_max > 0: depth_tensor /= depth_max depth = (depth_tensor * 255.0).clamp(0, 255).to(torch.uint8).cpu().numpy() if color_map != "none": colormap_key = color_map if color_map in masking.COLORMAP else "inferno" depth = cv2.applyColorMap(depth, masking.COLORMAP.index(colormap_key))[:, :, ::-1] if output_type == "pil": mode = "RGB" if depth.ndim == 3 else "L" depth = Image.fromarray(depth, mode=mode) return depth ================================================ FILE: modules/control/proc/dpt.py ================================================ from PIL import Image import numpy as np import torch from transformers import AutoImageProcessor, DPTForDepthEstimation from modules import devices from modules.shared import opts image_processor: AutoImageProcessor = None class DPTDetector: def __init__(self, model=None, processor=None, model_path=None): self.model = model self.processor = processor self.model_path = model_path or "Intel/dpt-large" def __call__(self, input_image=None, model_path=None): from modules.control.processors import cache_dir if model_path is not None and model_path != self.model_path: self.model_path = model_path self.processor = None self.model = None if self.processor is None: self.processor = AutoImageProcessor.from_pretrained(self.model_path, cache_dir=cache_dir) if self.model is None: self.model = 
DPTForDepthEstimation.from_pretrained(self.model_path, cache_dir=cache_dir) self.model.to(devices.device) with devices.inference_context(): inputs = self.processor(images=input_image, return_tensors="pt") inputs.to(devices.device) outputs = self.model(**inputs) predicted_depth = outputs.predicted_depth prediction = torch.nn.functional.interpolate( predicted_depth.unsqueeze(1), size=input_image.size[::-1], mode="bicubic", align_corners=False, ) output = prediction.squeeze().cpu().numpy() formatted = (output * 255 / np.max(output)).astype("uint8") if opts.control_move_processor: self.model.to('cpu') depth = Image.fromarray(formatted) depth = depth.convert('RGB') return depth ================================================ FILE: modules/control/proc/dwpose/__init__.py ================================================ # Openpose # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose # 2nd Edited by https://github.com/Hzzone/pytorch-openpose # 3rd Edited by ControlNet # 4th Edited by ControlNet (added face and correct hands) from typing import Type, Optional, Union, List import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" import cv2 import numpy as np from PIL import Image from installer import installed, pip, log from modules.control.util import HWC3, resize_image from .draw import draw_bodypose, draw_handpose, draw_facepose checked_ok = False busy = False def _register_module(self, module: Type, module_name: Optional[Union[str, List[str]]] = None, force: bool = False) -> None: if not callable(module): raise TypeError(f'module must be Callable, but got {type(module)}') if module_name is None: module_name = module.__name__ if isinstance(module_name, str): module_name = [module_name] for name in module_name: if not force and name in self._module_dict: # pylint: disable=protected-access pass # patch for 'Adafactor is already registered in optimizer at torch.optim' self._module_dict[name] = module # pylint: disable=protected-access def 
check_dependencies(): global checked_ok, busy # pylint: disable=global-statement busy = True debug = log.trace if os.environ.get('SD_DWPOSE_DEBUG', None) is not None else lambda *args, **kwargs: None # pip install --upgrade --no-deps --force-reinstall termcolor xtcocotools terminaltables pycocotools munkres shapely openmim==0.3.9 mmengine==0.10.5 mmcv==2.1.0 mmpose==1.3.2 mmdet==3.3.0 packages = [ 'termcolor', 'xtcocotools', 'terminaltables', 'pycocotools', 'munkres', 'shapely', 'openmim==0.3.9', 'mmengine==0.10.5', 'mmcv==2.1.0', 'mmpose==1.3.2', 'mmdet==3.3.0', ] status = [installed(p, reload=False, quiet=True) for p in packages] debug(f'DWPose required={packages} status={status}') if not all(status): log.info(f'Installing dependencies: for=dwpose packages={packages}') cmd = 'install --upgrade --no-deps --force-reinstall ' pkgs = ' '.join(packages) pip(cmd + pkgs, ignore=False, quiet=True, uv=False) try: import pkg_resources import imp # pylint: disable=deprecated-module imp.reload(pkg_resources) import mmcv # pylint: disable=unused-import import mmengine # pylint: disable=unused-import from mmengine.registry import Registry Registry._register_module = _register_module # pylint: disable=protected-access import mmpose # pylint: disable=unused-import import mmdet # pylint: disable=unused-import debug('DWPose import ok') checked_ok = True except Exception as e: log.error(f'DWPose: {e}') # from modules import errors # errors.display(e, 'DWPose') busy = False return checked_ok def draw_pose(pose, H, W): bodies = pose['bodies'] faces = pose['faces'] hands = pose['hands'] candidate = bodies['candidate'] subset = bodies['subset'] canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) canvas = draw_bodypose(canvas, candidate, subset) canvas = draw_handpose(canvas, hands) canvas = draw_facepose(canvas, faces) return canvas class DWposeDetector: def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"): self.pose_estimation = None if not 
checked_ok: if not check_dependencies(): return Wholebody = None try: from .wholebody import Wholebody except Exception as e: log.error(f'DWPose: {e}') if Wholebody is not None: self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device) def to(self, device): self.pose_estimation.to(device) return self def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", min_confidence=0.3, **kwargs): if self.pose_estimation is None: log.error("DWPose: not loaded") return None input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR) input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) H, W, _C = input_image.shape candidate, subset = self.pose_estimation(input_image) if candidate is None: return Image.fromarray(input_image) nums, _keys, locs = candidate.shape candidate[..., 0] /= float(W) candidate[..., 1] /= float(H) body = candidate[:,:18].copy() body = body.reshape(nums*18, locs) score = subset[:,:18] for i in range(len(score)): for j in range(len(score[i])): if score[i][j] > min_confidence: score[i][j] = int(18*i+j) else: score[i][j] = -1 un_visible = subset < min_confidence candidate[un_visible] = -1 _foot = candidate[:,18:24] faces = candidate[:,24:92] hands = candidate[:,92:113] hands = np.vstack([hands, candidate[:,113:]]) bodies = dict(candidate=body, subset=score) pose = dict(bodies=bodies, hands=hands, faces=faces) detected_map = draw_pose(pose, H, W) detected_map = HWC3(detected_map) img = resize_image(input_image, image_resolution) H, W, _C = img.shape detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) if output_type == "pil": detected_map = Image.fromarray(detected_map) return detected_map ================================================ FILE: modules/control/proc/dwpose/config/dwpose-l_384x288.py ================================================ # runtime max_epochs = 270 stage2_num_epochs = 30 base_lr = 
4e-3 train_cfg = dict(max_epochs=max_epochs, val_interval=10) randomness = dict(seed=21) # optimizer optim_wrapper = dict( type='OptimWrapper', optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), paramwise_cfg=dict( norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) # learning rate param_scheduler = [ dict( type='LinearLR', start_factor=1.0e-5, by_epoch=False, begin=0, end=1000), dict( # use cosine lr from 150 to 300 epoch type='CosineAnnealingLR', eta_min=base_lr * 0.05, begin=max_epochs // 2, end=max_epochs, T_max=max_epochs // 2, by_epoch=True, convert_to_iter_based=True), ] # automatically scaling LR based on the actual training batch size auto_scale_lr = dict(base_batch_size=512) # codec settings codec = dict( type='SimCCLabel', input_size=(288, 384), sigma=(6., 6.93), simcc_split_ratio=2.0, normalize=False, use_dark=False) # model settings model = dict( type='TopdownPoseEstimator', data_preprocessor=dict( type='PoseDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], bgr_to_rgb=True), backbone=dict( _scope_='mmdet', type='CSPNeXt', arch='P5', expand_ratio=0.5, deepen_factor=1., widen_factor=1., out_indices=(4, ), channel_attention=True, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='SiLU'), init_cfg=dict( type='Pretrained', prefix='backbone.', checkpoint='https://download.openmmlab.com/mmpose/v1/projects/' 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' )), head=dict( type='RTMCCHead', in_channels=1024, out_channels=133, input_size=codec['input_size'], in_featuremap_size=(9, 12), simcc_split_ratio=codec['simcc_split_ratio'], final_layer_kernel_size=7, gau_cfg=dict( hidden_dims=256, s=128, expansion_factor=2, dropout_rate=0., drop_path=0., act_fn='SiLU', use_rel_bias=False, pos_enc=False), loss=dict( type='KLDiscretLoss', use_target_weight=True, beta=10., label_softmax=True), decoder=codec), test_cfg=dict(flip_test=True, )) # base dataset settings dataset_type = 'CocoWholeBodyDataset' 
data_mode = 'topdown' data_root = '/data/' backend_args = dict(backend='local') # backend_args = dict( # backend='petrel', # path_mapping=dict({ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/', # f'{data_root}': 's3://openmmlab/datasets/detection/coco/' # })) # pipelines train_pipeline = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='RandomFlip', direction='horizontal'), dict(type='RandomHalfBody'), dict( type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='mmdet.YOLOXHSVRandomAug'), dict( type='Albumentation', transforms=[ dict(type='Blur', p=0.1), dict(type='MedianBlur', p=0.1), dict( type='CoarseDropout', max_holes=1, max_height=0.4, max_width=0.4, min_holes=1, min_height=0.2, min_width=0.2, p=1.0), ]), dict(type='GenerateTarget', encoder=codec), dict(type='PackPoseInputs') ] val_pipeline = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='PackPoseInputs') ] train_pipeline_stage2 = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='RandomFlip', direction='horizontal'), dict(type='RandomHalfBody'), dict( type='RandomBBoxTransform', shift_factor=0., scale_factor=[0.75, 1.25], rotate_factor=60), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='mmdet.YOLOXHSVRandomAug'), dict( type='Albumentation', transforms=[ dict(type='Blur', p=0.1), dict(type='MedianBlur', p=0.1), dict( type='CoarseDropout', max_holes=1, max_height=0.4, max_width=0.4, min_holes=1, min_height=0.2, min_width=0.2, p=0.5), ]), dict(type='GenerateTarget', encoder=codec), dict(type='PackPoseInputs') ] datasets = [] dataset_coco=dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='coco/annotations/coco_wholebody_train_v1.0.json', 
data_prefix=dict(img='coco/train2017/'), pipeline=[], ) datasets.append(dataset_coco) scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] for i in range(len(scene)): datasets.append( dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', data_prefix=dict(img='UBody/images/'+scene[i]+'/'), pipeline=[], ) ) # data loaders train_dataloader = dict( batch_size=32, num_workers=10, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( type='CombinedDataset', metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), datasets=datasets, pipeline=train_pipeline, test_mode=False, )) val_dataloader = dict( batch_size=32, num_workers=10, persistent_workers=True, drop_last=False, sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), dataset=dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='coco/annotations/coco_wholebody_val_v1.0.json', bbox_file=f'{data_root}coco/person_detection_results/' 'COCO_val2017_detections_AP_H_56_person.json', data_prefix=dict(img='coco/val2017/'), test_mode=True, pipeline=val_pipeline, )) test_dataloader = val_dataloader # hooks default_hooks = dict( checkpoint=dict( save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) custom_hooks = [ dict( type='EMAHook', ema_type='ExpMomentumEMA', momentum=0.0002, update_buffers=True, priority=49), dict( type='mmdet.PipelineSwitchHook', switch_epoch=max_epochs - stage2_num_epochs, switch_pipeline=train_pipeline_stage2) ] # evaluators val_evaluator = dict( type='CocoWholeBodyMetric', ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json') test_evaluator = val_evaluator ================================================ FILE: 
# RTMPose-style whole-body pose config: CSPNeXt backbone + RTMCCHead, 133 keypoints,
# codec input size (288, 384), COCO-WholeBody + UBody training datasets.
# _base_ = ['../../../_base_/default_runtime.py']

# runtime
max_epochs = 270
stage2_num_epochs = 30  # the last 30 epochs run with `train_pipeline_stage2` (see custom_hooks)
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    # iteration-based linear warm-up over the first 1000 iterations
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # cosine annealing over the second half of training:
        # from epoch max_epochs // 2 (135) to max_epochs (270)
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings (shared by the head, the pipelines and the target generator below)
codec = dict(
    type='SimCCLabel',
    input_size=(288, 384),
    sigma=(6., 6.93),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1.,
        widen_factor=1.,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=1024,
        out_channels=133,
        input_size=codec['input_size'],
        in_featuremap_size=(9, 12),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'CocoWholeBodyDataset'
data_mode = 'topdown'
data_root = 'data/'

backend_args = dict(backend='local')
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
#     }))

# pipelines
# stage-1 training pipeline
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# stage-2 training pipeline: compared with stage 1 it sets shift_factor=0., narrows the
# scale range to [0.75, 1.25], lowers rotate_factor to 60 and CoarseDropout p to 0.5
train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

# training datasets: COCO-WholeBody train split plus one UBody dataset per scene
datasets = []
dataset_coco=dict(
    type=dataset_type,
    data_root=data_root,
    data_mode=data_mode,
    ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
    data_prefix=dict(img='coco/train2017/'),
    pipeline=[],
)
datasets.append(dataset_coco)

scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
         'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
         'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']

# NOTE(review): the loop variable `i` remains a module-level name after the loop;
# left unchanged in case the config loader reads module-level names — confirm before renaming.
for i in range(len(scene)):
    datasets.append(
        dict(
            type=dataset_type,
            data_root=data_root,
            data_mode=data_mode,
            ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
            data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
            pipeline=[],
        )
    )

# data loaders
train_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
        bbox_file=f'{data_root}coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        data_prefix=dict(img='coco/val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
# testing reuses the validation loader
test_dataloader = val_dataloader

# hooks
default_hooks = dict(
    checkpoint=dict(
        save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    # switch to the stage-2 pipeline at epoch max_epochs - stage2_num_epochs (240)
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
in_featuremap_size=(6, 8), simcc_split_ratio=codec['simcc_split_ratio'], final_layer_kernel_size=7, gau_cfg=dict( hidden_dims=256, s=128, expansion_factor=2, dropout_rate=0., drop_path=0., act_fn='SiLU', use_rel_bias=False, pos_enc=False), loss=dict( type='KLDiscretLoss', use_target_weight=True, beta=10., label_softmax=True), decoder=codec), test_cfg=dict(flip_test=True, )) # base dataset settings dataset_type = 'CocoWholeBodyDataset' data_mode = 'topdown' data_root = 'data/' backend_args = dict(backend='local') # backend_args = dict( # backend='petrel', # path_mapping=dict({ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/', # f'{data_root}': 's3://openmmlab/datasets/detection/coco/' # })) # pipelines train_pipeline = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='RandomFlip', direction='horizontal'), dict(type='RandomHalfBody'), dict( type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='mmdet.YOLOXHSVRandomAug'), dict( type='Albumentation', transforms=[ dict(type='Blur', p=0.1), dict(type='MedianBlur', p=0.1), dict( type='CoarseDropout', max_holes=1, max_height=0.4, max_width=0.4, min_holes=1, min_height=0.2, min_width=0.2, p=1.0), ]), dict(type='GenerateTarget', encoder=codec), dict(type='PackPoseInputs') ] val_pipeline = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='PackPoseInputs') ] train_pipeline_stage2 = [ dict(type='LoadImage', backend_args=backend_args), dict(type='GetBBoxCenterScale'), dict(type='RandomFlip', direction='horizontal'), dict(type='RandomHalfBody'), dict( type='RandomBBoxTransform', shift_factor=0., scale_factor=[0.75, 1.25], rotate_factor=60), dict(type='TopdownAffine', input_size=codec['input_size']), dict(type='mmdet.YOLOXHSVRandomAug'), dict( type='Albumentation', transforms=[ 
dict(type='Blur', p=0.1), dict(type='MedianBlur', p=0.1), dict( type='CoarseDropout', max_holes=1, max_height=0.4, max_width=0.4, min_holes=1, min_height=0.2, min_width=0.2, p=0.5), ]), dict(type='GenerateTarget', encoder=codec), dict(type='PackPoseInputs') ] datasets = [] dataset_coco=dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='coco/annotations/coco_wholebody_train_v1.0.json', data_prefix=dict(img='coco/train2017/'), pipeline=[], ) datasets.append(dataset_coco) scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'] for i in range(len(scene)): datasets.append( dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json', data_prefix=dict(img='UBody/images/'+scene[i]+'/'), pipeline=[], ) ) # data loaders train_dataloader = dict( batch_size=64, num_workers=10, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( type='CombinedDataset', metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), datasets=datasets, pipeline=train_pipeline, test_mode=False, )) val_dataloader = dict( batch_size=32, num_workers=10, persistent_workers=True, drop_last=False, sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), dataset=dict( type=dataset_type, data_root=data_root, data_mode=data_mode, ann_file='coco/annotations/coco_wholebody_val_v1.0.json', bbox_file=f'{data_root}coco/person_detection_results/' 'COCO_val2017_detections_AP_H_56_person.json', data_prefix=dict(img='coco/val2017/'), test_mode=True, pipeline=val_pipeline, )) test_dataloader = val_dataloader # hooks default_hooks = dict( checkpoint=dict( save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) custom_hooks = [ dict( type='EMAHook', ema_type='ExpMomentumEMA', momentum=0.0002, 
# RTMPose-t whole-body pose config (COCO-WholeBody + UBody, 192x256 SimCC codec).
# Every module-level name below is part of the config, so names must not be changed.
# _base_ = ['../../../_base_/default_runtime.py']

# runtime
max_epochs = 270
stage2_num_epochs = 30  # number of final epochs trained with `train_pipeline_stage2`
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    dict(
        # iteration-based linear warm-up over the first 1000 iterations
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # cosine annealing over the second half of training:
        # from epoch max_epochs // 2 (135) to max_epochs (270)
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(
    type='SimCCLabel',
    input_size=(192, 256),
    sigma=(4.9, 5.66),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=0.167,
        widen_factor=0.375,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmpose/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth'
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=384,
        out_channels=133,
        # input size and split ratio are shared with the codec above
        input_size=codec['input_size'],
        in_featuremap_size=(6, 8),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'CocoWholeBodyDataset'
data_mode = 'topdown'
data_root = 'data/'

backend_args = dict(backend='local')
# Alternative remote backend, kept for reference
# (NOTE: both mapping keys below are the same f-string):
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
#     }))

# pipelines
# stage-1 training augmentation
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# stage-2 training augmentation (no bbox shift, narrower scale/rotation ranges,
# CoarseDropout applied with p=0.5); swapped in by the PipelineSwitchHook below
train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

# training datasets: COCO-WholeBody first, then one UBody dataset per scene
datasets = []
dataset_coco=dict(
    type=dataset_type,
    data_root=data_root,
    data_mode=data_mode,
    ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
    data_prefix=dict(img='coco/train2017/'),
    pipeline=[],
)
datasets.append(dataset_coco)

scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
         'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
         'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']

# the scene name selects both the annotation folder and the image folder
for i in range(len(scene)):
    datasets.append(
        dict(
            type=dataset_type,
            data_root=data_root,
            data_mode=data_mode,
            ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
            data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
            pipeline=[],
        )
    )

# data loaders
train_dataloader = dict(
    batch_size=64,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
        bbox_file=f'{data_root}coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        data_prefix=dict(img='coco/val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
# testing reuses the validation loader
test_dataloader = val_dataloader

# hooks
default_hooks = dict(
    checkpoint=dict(
        save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    dict(
        # switch to the stage-2 pipeline for the last `stage2_num_epochs` epochs
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
# YOLOX-L detector config (COCO, 640x640 input).
# Every module-level name below is part of the config, so names must not be changed.
img_scale = (640, 640)  # width, height

# model settings
model = dict(
    type='YOLOX',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        pad_size_divisor=32,
        batch_augments=[
            dict(
                type='BatchSyncRandomResize',
                random_size_range=(480, 800),
                size_divisor=32,
                interval=10)
        ]),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=1.0,
        widen_factor=1.0,
        out_indices=(2, 3, 4),
        use_depthwise=False,
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        use_depthwise=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish')),
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=80,
        in_channels=256,
        feat_channels=256,
        stacked_convs=2,
        strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # To align with the source code, the score threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

# dataset settings
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

# Example of using a different file client
# Method 1: simply set the data root and let the file I/O module
# automatically infer the backend from the prefix
# (LMDB and Memcache are not supported yet)

# data_root = 's3://openmmlab/datasets/detection/coco/'

# Method 2: use `backend_args` (`file_client_args` in versions before 3.0.0rc6)
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
backend_args = None

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    # According to the official implementation, multi-scale
    # training is not handled here but in
    # 'mmdet/models/detectors/yolox.py'.
    # Resize and Pad are for the last 15 epochs, when Mosaic,
    # RandomAffine, and MixUp are disabled by YOLOXModeSwitchHook.
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]

train_dataset = dict(
    # use MultiImageMixDataset wrapper to support mosaic and mixup
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[
            dict(type='LoadImageFromFile', backend_args=backend_args),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        backend_args=backend_args),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
val_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
# testing reuses the validation loader
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    backend_args=backend_args)
test_evaluator = val_evaluator

# training settings
max_epochs = 300
num_last_epochs = 15
interval = 10  # used for both the validation and the checkpoint interval

train_cfg = dict(max_epochs=max_epochs, val_interval=interval)

# optimizer
# default 8 gpu
base_lr = 0.01
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
        nesterov=True),
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))

# learning rate
param_scheduler = [
    dict(
        # use quadratic formula to warm up 5 epochs
        # and lr is updated by iteration
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        # use cosine lr from 5 to 285 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        # use fixed lr during last 15 epochs
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    )
]

default_hooks = dict(
    checkpoint=dict(
        interval=interval,
        max_keep_ckpts=3  # only keep latest 3 checkpoints
    ))

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
def smart_resize(x, s):
    """Resize an image array to an explicit target size.

    Args:
        x: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` numpy array.
        s: Target size as ``(height, width)``.

    Returns:
        The resized array. Arrays with 1 or 3 channels are resized in a single
        ``cv2.resize`` call; any other channel count is resized channel by
        channel and re-stacked along axis 2.
    """
    Ht, Wt = s
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    if Co == 3 or Co == 1:
        # k < 1 means the image shrinks overall: area interpolation for
        # downscaling, Lanczos otherwise
        k = float(Ht + Wt) / float(Ho + Wo)
        interpolation = cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4
        # cv2.resize takes the target size as (width, height)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=interpolation)
    return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)


def smart_resize_k(x, fx, fy):
    """Resize an image array by scale factors.

    Args:
        x: 2-D ``(H, W)`` or 3-D ``(H, W, C)`` numpy array.
        fx: Horizontal scale factor (applied to the width).
        fy: Vertical scale factor (applied to the height).

    Returns:
        The resized array, produced by `smart_resize` with the target size
        ``(H * fy, W * fx)``.
    """
    if x.ndim == 2:
        Ho, Wo = x.shape
    else:
        Ho, Wo, _Co = x.shape
    # same computation as before, without duplicating the resize logic
    return smart_resize(x, (Ho * fy, Wo * fx))


def padRightDownCorner(img, stride, padValue):
    """Pad an image on the bottom and right so both sides are multiples of ``stride``.

    Args:
        img: 3-D ``(H, W, C)`` numpy array.
        stride: Positive integer the padded height and width must be divisible by.
        padValue: Constant used to fill the padded area.

    Returns:
        Tuple ``(img_padded, pad)`` where ``pad`` is ``[up, left, down, right]``
        (up and left are always 0).
    """
    h = img.shape[0]
    w = img.shape[1]

    pad = [
        0,  # up
        0,  # left
        0 if (h % stride == 0) else stride - (h % stride),  # down
        0 if (w % stride == 0) else stride - (w % stride),  # right
    ]

    img_padded = img
    # `row * 0 + padValue` builds a constant template with matching shape;
    # the zero-sized up/left pads are still concatenated so the resulting
    # dtype promotion is the same as before
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    # use the LAST row/column as template: the former `[-2:-1]` slice is empty
    # for a 1-pixel-high (or wide) image, which silently skipped the padding
    pad_down = np.tile(img_padded[-1:, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -1:, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
def transfer(model, model_weights):
    """Re-key a weight dict to the parameter names of ``model``.

    For every name in ``model.state_dict()`` the first dot-separated component
    is dropped and the remainder is looked up in ``model_weights``.

    Returns:
        Dict mapping the model's own parameter names to the looked-up weights.
    """
    return {
        name: model_weights[name.split('.', 1)[1] if '.' in name else '']
        for name in model.state_dict().keys()
    }


def draw_bodypose(canvas, candidate, subset):
    """Draw body limbs and joints onto ``canvas``.

    Args:
        canvas: ``(H, W, C)`` image array; limbs are painted onto it in place.
        candidate: Keypoint table whose first two columns are x and y; they are
            multiplied by the canvas width/height before drawing.
        subset: One row per person holding indices into ``candidate``
            (``-1`` marks a missing keypoint).

    Returns:
        A new uint8 canvas: the limb drawing dimmed by 0.6 with joint circles
        drawn on top.
    """
    H, W, _C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    # only the first 17 limb definitions are drawn
    for limb, color in zip(limbSeq[:17], colors):
        joint_slots = np.array(limb) - 1  # limbSeq is 1-based
        for person in subset:
            index = person[joint_slots]
            if -1 in index:
                continue
            rows = index.astype(int)
            # Y holds the horizontal and X the vertical pixel coordinates
            Y = candidate[rows, 0] * float(W)
            X = candidate[rows, 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, color)

    # dim the limbs so the joint circles stand out
    canvas = (canvas * 0.6).astype(np.uint8)

    for slot, color in enumerate(colors):
        for person in subset:
            index = int(person[slot])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            center = (int(int(x * W)), int(int(y * H)))
            cv2.circle(canvas, center, 4, color, thickness=-1)

    return canvas
def draw_handpose(canvas, all_hand_peaks):
    """Draw hand skeletons onto ``canvas`` in place and return it.

    Args:
        canvas: ``(H, W, C)`` image array.
        all_hand_peaks: Iterable of hands, each a sequence of 21 ``(x, y)``
            points that are scaled by the canvas width/height before drawing.

    Returns:
        The same ``canvas`` object.
    """
    import matplotlib as mpl

    H, W, _C = canvas.shape

    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
    edge_count = float(len(edges))

    # (person_number*2, 21, 2)
    for hand in all_hand_peaks:
        peaks = np.array(hand)

        for ie, (start, end) in enumerate(edges):
            x1, y1 = peaks[start]
            x2, y2 = peaks[end]
            p1 = (int(x1 * W), int(y1 * H))
            p2 = (int(x2 * W), int(y2 * H))
            # an endpoint at (or below) pixel 0 is treated as "not detected"
            if min(p1[0], p1[1], p2[0], p2[1]) > eps:
                # one hue per edge, spread over the full colour wheel
                color = mpl.colors.hsv_to_rgb([ie / edge_count, 1.0, 1.0]) * 255
                cv2.line(canvas, p1, p2, color, thickness=2)

        for point in peaks:
            x, y = point
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas


def draw_facepose(canvas, all_lmks):
    """Draw face landmarks as white dots onto ``canvas`` in place and return it.

    Args:
        canvas: ``(H, W, C)`` image array.
        all_lmks: Iterable of faces, each a sequence of ``(x, y)`` points that
            are scaled by the canvas width/height before drawing.

    Returns:
        The same ``canvas`` object.
    """
    H, W, _C = canvas.shape
    for face in all_lmks:
        for lmk in np.array(face):
            x, y = lmk
            px = int(x * W)
            py = int(y * H)
            # a landmark at (or below) pixel 0 is treated as "not detected"
            if px > eps and py > eps:
                cv2.circle(canvas, (px, py), 3, (255, 255, 255), thickness=-1)
    return canvas
posePtr[elbow*3]); # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); x = x3 + ratioWristElbow * (x3 - x2) y = y3 + ratioWristElbow * (y3 - y2) distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) # x-y refers to the center --> offset to topLeft point # handRectangle.x -= handRectangle.width / 2.f; # handRectangle.y -= handRectangle.height / 2.f; x -= width / 2 y -= width / 2 # width = height # overflow the image if x < 0: x = 0 if y < 0: y = 0 width1 = width width2 = width if x + width > image_width: width1 = image_width - x if y + width > image_height: width2 = image_height - y width = min(width1, width2) # the max hand box value is 20 pixels if width >= 20: detect_result.append([int(x), int(y), int(width), is_left]) ''' return value: [[x, y, w, True if left hand else False]]. width=height since the network require squared input. 
# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
    """Estimate square face boxes from head, eye and ear keypoints.

    Args:
        candidate: Keypoint table whose first two columns are x and y.
        subset: One row per person holding indices into ``candidate``
            (``-1`` marks a missing keypoint). Slot 0 is the head, slots
            14/15 the left/right eye and 16/17 the left/right ear.
        oriImg: Image array; only its height and width are used for clipping.

    Returns:
        ``[[x, y, w], ...]`` where ``x, y`` is the top-left corner and ``w``
        the side of the square box, clipped so the box stays inside the image.
        Boxes smaller than 20 pixels are dropped.
    """
    # (slot in the person row, multiplier applied to the head distance)
    face_slots = ((14, 3.0), (15, 3.0), (16, 1.5), (17, 1.5))

    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        head = person[0]
        if not head > -1:
            continue
        # at least one eye or ear is needed to size the box
        if not any(person[slot] > -1 for slot, _scale in face_slots):
            continue

        x0, y0 = candidate[head][:2]

        # half side = largest scaled distance between the head and any detected eye/ear
        width = 0.0
        for slot, scale in face_slots:
            if person[slot] > -1:
                x1, y1 = candidate[person[slot]][:2]
                d = max(abs(x0 - x1), abs(y0 - y1))
                width = max(width, d * scale)

        # the head is the box centre --> offset to the top-left point
        x = x0 - width
        y = y0 - width

        if x < 0:
            x = 0

        if y < 0:
            y = 0

        # the full side is twice the half side; compare the FULL side with the
        # image bounds (the former check used the half side, so a box could
        # still extend past the right/bottom edge)
        width1 = width * 2
        width2 = width * 2

        if x + width1 > image_width:
            width1 = image_width - x

        if y + width2 > image_height:
            width2 = image_height - y

        width = min(width1, width2)

        if width >= 20:
            detect_result.append([int(x), int(y), int(width)])

    return detect_result


# get max index of 2d array
def npmax(array):
    """Return ``(row, col)`` of the first maximum of a 2-D array."""
    arrayindex = array.argmax(1)  # best column per row
    arrayvalue = array.max(1)  # best value per row
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j
class Wholebody:
    """Person detector plus whole-body pose estimator built on OpenMMLab.

    When the OpenMMLab packages could not be imported (module flag ``mmok`` is
    False) the object is still constructible, but it is inert: calling it
    returns ``(None, None)`` and ``to()`` is a no-op.
    """

    def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
        """Build the detector and the pose estimator.

        Args:
            det_config: Detector config path; defaults to the bundled YOLOX-L config.
            det_ckpt: Detector checkpoint path or URL.
            pose_config: Pose config path; defaults to the bundled DWPose-L config.
            pose_ckpt: Pose checkpoint path or URL.
            device: Device both models are initialised on.
        """
        if not mmok:
            # inert placeholders so attribute access does not fail later
            self.detector = lambda *args, **kwargs: None
            self.pose_estimator = None
            return
        prefix = os.path.dirname(__file__)
        if det_config is None:
            det_config = "config/yolox_l_8xb8-300e_coco.py"
        if pose_config is None:
            pose_config = "config/dwpose-l_384x288.py"
        # config paths are resolved relative to this module unless they already
        # point inside it (the check previously compared against the literal
        # string 'prefix' instead of the variable)
        if not det_config.startswith(prefix):
            det_config = os.path.join(prefix, det_config)
        if not pose_config.startswith(prefix):
            pose_config = os.path.join(prefix, pose_config)
        if det_ckpt is None:
            det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
        if pose_ckpt is None:
            pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
        # build detector
        self.detector = init_detector(det_config, det_ckpt, device=device)
        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
        # build pose estimator
        self.pose_estimator = init_pose_estimator(
            pose_config,
            pose_ckpt,
            device=device)

    def to(self, device):
        """Move both models to ``device`` and return ``self``."""
        if not mmok:
            # nothing to move: the placeholders have no ``to`` method
            return self
        self.detector.to(device)
        self.pose_estimator.to(device)
        return self

    def __call__(self, oriImg):
        """Detect people in ``oriImg`` and estimate their whole-body keypoints.

        Returns:
            ``(keypoints, scores)``: the x/y coordinates and the score of every
            keypoint per detected instance, re-ordered with an inserted neck
            joint (see the index remapping below), or ``(None, None)`` when
            OpenMMLab is unavailable.
        """
        if not mmok:
            return None, None
        # predict bbox
        det_result = inference_detector(self.detector, oriImg)
        pred_instance = det_result.pred_instances.cpu().numpy()
        bboxes = np.concatenate((pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
        # keep label 0 with a score above 0.5
        bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > 0.5)]
        # set NMS threshold
        bboxes = bboxes[nms(bboxes, 0.7), :4]

        # predict keypoints; without any box the estimator is called on the image alone
        if len(bboxes) == 0:
            pose_results = inference_topdown(self.pose_estimator, oriImg)
        else:
            pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
        preds = merge_data_samples(pose_results)
        preds = preds.pred_instances

        # preds = pose_results[0].pred_instances
        keypoints = preds.get('transformed_keypoints', preds.keypoints)
        if 'keypoint_scores' in preds:
            scores = preds.keypoint_scores
        else:
            scores = np.ones(keypoints.shape[:-1])

        if 'keypoints_visible' in preds:
            visible = preds.keypoints_visible
        else:
            visible = np.ones(keypoints.shape[:-1])
        # last axis becomes (x, y, score, visible)
        keypoints_info = np.concatenate(
            (keypoints, scores[..., None], visible[..., None]),
            axis=-1)
        # compute neck joint as the mean of keypoints 5 and 6
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # neck score when visualizing pred: 1 only if both source joints exceed 0.3
        neck[:, 2:4] = np.logical_and(
            keypoints_info[:, 5, 2:4] > 0.3,
            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(
            keypoints_info, 17, neck, axis=1)
        # re-order the body joints from the mmpose layout to the openpose layout
        mmpose_idx = [
            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
        ]
        openpose_idx = [
            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
        ]
        new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]

        return keypoints, scores
class EdgeDetector:
    """Edge map extractor backed by OpenCV's ``ximgproc`` EdgeDrawing."""

    def __call__(self, input_image=None, pf=True, mode='edge', detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
        """Compute an edge (or gradient) map for ``input_image``.

        Args:
            input_image: Numpy array or an object convertible with ``np.array``.
            pf: Value assigned to the EdgeDrawing ``PFmode`` parameter.
            mode: ``'edge'`` returns the edge image, anything else the gradient image.
            detect_resolution: Resolution the input is resized to before detection.
            image_resolution: Resolution used to size the returned map.
            output_type: ``"pil"`` or ``"np"``; when omitted it follows the input
                type (``"np"`` for numpy input, ``"pil"`` otherwise).
            **kwargs: Only the deprecated ``img`` alias of ``input_image`` is read.

        Raises:
            ImportError: The EdgeDrawing detector cannot be created.
            ValueError: No input image was given.
        """
        global ed # pylint: disable=global-statement
        # the detector is created once and shared through the module-level `ed`
        if ed is None:
            try:
                ed = cv2.ximgproc.createEdgeDrawing()
            except Exception as e:
                raise ImportError("Edge processor: invalid version of OpenCV found") from e
        drawing_params = cv2.ximgproc.EdgeDrawing.Params()
        drawing_params.PFmode = pf
        ed.setParams(drawing_params)

        if "img" in kwargs:
            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
            input_image = kwargs.pop("img")
        if input_image is None:
            raise ValueError("input_image must be defined.")

        was_array = isinstance(input_image, np.ndarray)
        if not was_array:
            input_image = np.array(input_image, dtype=np.uint8)
        output_type = output_type or ("np" if was_array else "pil")

        input_image = resize_image(HWC3(input_image), detect_resolution)
        gray = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)
        edges = ed.detectEdges(gray)
        edge_map = ed.getEdgeImage(edges) if mode == 'edge' else ed.getGradientImage(edges)

        # single channel -> three channels
        edge_map = np.expand_dims(edge_map, axis=2)
        edge_map = cv2.cvtColor(edge_map, cv2.COLOR_GRAY2BGR).astype(np.uint8)
        edge_map = HWC3(edge_map)

        # the output takes the size of the input resized to `image_resolution`
        target = resize_image(input_image, image_resolution)
        H, W, _C = target.shape
        edge_map = cv2.resize(edge_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            edge_map = Image.fromarray(edge_map)
        return edge_map
class GLPNDetector:
    """Depth estimator using the ``vinvino02/glpn-kitti`` GLPN model."""

    def __init__(self, model=None, processor=None):
        """Store an optional preloaded model/processor; missing ones are loaded on first call."""
        self.model = model
        self.processor = processor

    def __call__(self, input_image=None):
        """Return an RGB PIL depth image for ``input_image``.

        ``input_image`` must expose ``.size`` as ``(width, height)``, which is
        reversed to size the interpolated prediction.
        """
        from modules.control.processors import cache_dir
        repo_id = "vinvino02/glpn-kitti"
        # lazy-load whatever was not supplied to the constructor
        if self.processor is None:
            self.processor = AutoImageProcessor.from_pretrained(repo_id, cache_dir=cache_dir)
        if self.model is None:
            self.model = GLPNForDepthEstimation.from_pretrained(repo_id, cache_dir=cache_dir)
        self.model.to(devices.device)

        with devices.inference_context():
            inputs = self.processor(images=input_image, return_tensors="pt")
            inputs.to(devices.device)
            predicted_depth = self.model(**inputs).predicted_depth
            # resize the prediction back to the input size (height, width)
            prediction = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=input_image.size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            output = prediction.squeeze().cpu().numpy()
            # scale to 0..255 and invert
            formatted = 255 - (output * 255 / np.max(output)).astype("uint8")

        if opts.control_move_processor:
            self.model.to('cpu')

        return Image.fromarray(formatted).convert('RGB')
class DoubleConvBlock(torch.nn.Module): # pylint: disable=abstract-method
    """Stack of 3x3 convolutions followed by a 1x1 projection to one channel."""

    def __init__(self, input_channel, output_channel, layer_number):
        """Create ``layer_number`` 3x3 convolutions (the first maps
        ``input_channel`` to ``output_channel``) plus the 1x1 projection."""
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3), stride=(1, 1), padding=1)
        layers = [torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, **conv_kwargs)]
        layers.extend(
            torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, **conv_kwargs)
            for _i in range(1, layer_number)
        )
        self.convs = torch.nn.Sequential(*layers)
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        """Return ``(features, projection)``; with ``down_sampling`` the input is
        first reduced by a 2x2 max-pool."""
        h = torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2)) if down_sampling else x
        # every convolution is followed by a ReLU
        for conv in self.convs:
            h = torch.nn.functional.relu(conv(h))
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module): # pylint: disable=abstract-method
    """Five chained `DoubleConvBlock` stages producing one projection per stage."""

    def __init__(self):
        super().__init__()
        # subtracted from the input in `__call__`
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        """Return the five stage projections as a tuple, first stage first."""
        h = x - self.norm
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        projections = []
        for position, stage in enumerate(stages):
            # every stage except the first starts with a 2x down-sampling
            h, projection = stage(h, down_sampling=position > 0)
            projections.append(projection)
        return tuple(projections)
class HEDdetector:
    """HED edge detector wrapping a `ControlNetHED_Apache2` network."""

    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        """Load the weights from a local directory or the Hugging Face hub.

        Args:
            pretrained_model_or_path: Local directory or hub repository id.
            filename: Weight file name; defaults to ``"ControlNetHED.pth"``.
            cache_dir: Hub cache directory (hub download only).
            local_files_only: Passed to the hub download.
        """
        filename = filename or "ControlNetHED.pth"
        if os.path.isdir(pretrained_model_or_path):
            weights_path = os.path.join(pretrained_model_or_path, filename)
        else:
            weights_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        net = ControlNetHED_Apache2()
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
        net.load_state_dict(torch.load(weights_path, map_location='cpu'))
        net.float().eval()
        return cls(net)

    def to(self, device):
        """Move the network to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs):
        """Compute the HED edge map of ``input_image``.

        Args:
            input_image: Numpy array or an object convertible with ``np.array``.
            detect_resolution: Resolution the input is resized to before detection.
            image_resolution: Resolution used to size the returned map.
            safe: Apply ``safe_step`` to the edge probabilities.
            output_type: ``"pil"`` returns a PIL image, anything else the numpy array.
            scribble: Post-process the map with NMS, blur and thresholding.
            **kwargs: Accepted but unused.
        """
        self.model.to(devices.device)
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = resize_image(HWC3(input_image), detect_resolution)

        assert input_image.ndim == 3
        H, W, _C = input_image.shape
        image_hed = torch.from_numpy(input_image.copy()).float().to(device)
        image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
        # one projection per network stage, each resized back to the input size
        stage_maps = [
            cv2.resize(e.detach().cpu().numpy().astype(np.float32)[0, 0], (W, H), interpolation=cv2.INTER_LINEAR)
            for e in self.model(image_hed)
        ]
        stacked = np.stack(stage_maps, axis=2)
        # sigmoid of the mean over all stages
        mean_logit = np.mean(stacked, axis=2).astype(np.float64)
        edge = 1 / (1 + np.exp(-mean_logit))
        if safe:
            edge = safe_step(edge)
        edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(edge)

        # the output takes the size of the input resized to `image_resolution`
        target = resize_image(input_image, image_resolution)
        H, W, _C = target.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if scribble:
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            # binarise: everything above 4 becomes white, the rest black
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        if opts.control_move_processor:
            self.model.to('cpu')

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
pix2pix_filename) else: model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir, local_files_only=local_files_only) opt = TestOptions().parse() if not torch.cuda.is_available(): opt.gpu_ids = [] # cpu mode pix2pixmodel = Pix2Pix4DepthModel(opt) pix2pixmodel.save_dir = os.path.dirname(model_path) pix2pixmodel.load_networks('latest') pix2pixmodel.eval() return cls(model, pix2pixmodel) def to(self, device): self.model.to(device) return self def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"): self.model.to(devices.device) # device = next(iter(self.model.parameters())).device if not isinstance(input_image, np.ndarray): input_image = np.array(input_image, dtype=np.uint8) input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) assert input_image.ndim == 3 height, width, _dim = input_image.shape if boost: depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height)) else: depth = estimateleres(input_image, self.model, width, height) numbytes=2 depth_min = depth.min() depth_max = depth.max() max_val = (2**(8*numbytes))-1 # check output before normalizing and mapping to 16 bit if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) else: out = np.zeros(depth.shape) # single channel, 16 bit image depth_image = out.astype("uint16") # convert to uint8 depth_image = cv2.convertScaleAbs(depth_image, alpha=255.0/65535.0) # remove near if thr_a != 0: thr_a = thr_a/100*255 depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1] # invert image depth_image = cv2.bitwise_not(depth_image) # remove bg if thr_b != 0: thr_b = thr_b/100*255 depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1] detected_map = depth_image detected_map = HWC3(detected_map) img = resize_image(input_image, image_resolution) H, W, _C = img.shape 
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) if opts.control_move_processor: self.model.to('cpu') if output_type == "pil": detected_map = Image.fromarray(detected_map) return detected_map ================================================ FILE: modules/control/proc/leres/leres/LICENSE ================================================ https://github.com/thygate/stable-diffusion-webui-depthmap-script MIT License Copyright (c) 2023 Bob Thiry Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with a 4x channel expansion.

    The stride is applied on the 3x3 convolution. ``downsample``, when given, is applied
    to the block input to produce the residual that is added before the final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + residual)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection module was supplied.
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
def resnet18(pretrained=True, **kwargs):
    """Construct a ResNet-18 backbone (``BasicBlock`` with layer counts [2, 2, 2, 2]).

    Args:
        pretrained (bool): Kept for signature compatibility only. It is ignored:
            this function neither downloads nor loads any weights, so the returned
            model always carries its constructor-time initialisation.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    # NOTE: `pretrained` is intentionally unused; weights are expected to be loaded by the caller.
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_kwargs = {
        'kernel_size': 3,
        'stride': stride,
        'padding': dilation,
        'groups': groups,
        'bias': False,
        'dilation': dilation,
    }
    layer = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
    return layer
class Bottleneck(nn.Module):
    """Grouped bottleneck residual block (1x1 -> grouped 3x3 -> 1x1), 4x channel expansion.

    The stride is placed on the 3x3 convolution (``conv2``) rather than on the first
    1x1 convolution, i.e. the variant commonly called ResNet V1.5
    (see https://arxiv.org/abs/1512.03385 for the original placement).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        # Inner width scales with the per-group width and the number of groups.
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # conv2 and the optional downsample module both reduce resolution when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = make_norm(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + identity)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)
        out += identity
        return self.relu(out)
dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) #self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
def resnext101_32x8d(pretrained=True, **kwargs):
    """Construct a ResNeXt-101 32x8d backbone (``Bottleneck``, layer counts [3, 4, 23, 3]).

    Args:
        pretrained (bool): Kept for signature compatibility only. It is ignored:
            no weights are downloaded or loaded by this function.
        **kwargs: Forwarded to the ``ResNet`` constructor. ``groups`` and
            ``width_per_group`` are always overwritten with 32 and 8.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    # The 32x8d configuration is fixed here regardless of what the caller passed.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
    """Search for the R_x processing resolution (section 5 of the main paper).

    Parameters:
        img: input rgb image
        basesize: size of the dilation kernel, equal to the receptive field of the network.
        confidence: value of x in R_x; allowed fraction of pixels that do not get any contextual cue.
        scale_threshold: maximum allowed upscaling of the input image (default 3).
        whole_size_threshold: maximum allowed resolution (R_max from section 6 of the main paper).

    Returns:
        Tuple ``(int(outputsize_scale * speed_scale), patch_scale)``: the computed R_x
        resolution and the K parameter from section 6 of the paper.
    """
    # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
    speed_scale = 32
    image_dim = int(min(img.shape[0:2]))

    # Gradient magnitude (|d/dy| + |d/dx| Sobel responses) of the grayscale image.
    gray = rgb2gray(img)
    grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
    # NOTE(review): the third positional parameter of cv2.resize is `dst`, not `interpolation`,
    # so cv2.INTER_AREA is presumably not applied as the interpolation mode here - confirm before changing,
    # since the search result depends on it.
    grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)

    # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
    m = grad.min()
    M = grad.max()
    middle = m + (0.4 * (M - m))
    grad[grad < middle] = 0
    grad[grad >= middle] = 1

    # dilation kernel with size of the receptive field
    kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
    # dilation kernel with the size of a quarter of the receptive field, used to compute k
    # as described in section 6 of main paper
    kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)

    # Output resolution limit set by the whole_size_threshold and scale_threshold.
    threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))

    # Start from the receptive-field size and grow in half-receptive-field steps
    # (all sizes expressed in units of speed_scale pixels).
    outputsize_scale = basesize / speed_scale
    for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
        grad_resized = resizewithpool(grad, p_size)
        # NOTE(review): same positional-flag pattern as above (cv2.INTER_NEAREST lands in the `dst` slot).
        grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
        grad_resized[grad_resized >= 0.5] = 1
        grad_resized[grad_resized < 0.5] = 0

        # Fraction of pixels not covered by any dilated edge, i.e. pixels without contextual cues.
        dilated = cv2.dilate(grad_resized, kernel, iterations=1)
        meanvalue = (1-dilated).mean()
        if meanvalue > confidence:
            break
        else:
            outputsize_scale = p_size

    # NOTE(review): grad_resized is only bound inside the loop; if the range above is empty
    # (threshold/speed_scale not larger than basesize/speed_scale) this line raises UnboundLocalError.
    grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
    patch_scale = grad_region.mean()

    return int(outputsize_scale*speed_scale), patch_scale
def applyGridpatch(blsize, stride, img, box):
    """Lay a regular grid of square patches over ``img``.

    Patch centres run from ``blsize`` up to (but excluding) ``dimension - blsize`` in
    steps of ``stride``, iterating columns in the outer loop and rows in the inner one.
    Each entry is keyed by its running index as a string and holds
    ``'rect'`` = [x, y, width, height] (offset by ``box[0]``/``box[1]``) and
    ``'size'`` = the rect width.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    centres = (
        (cx, cy)
        for cx in range(blsize, n_cols - blsize, stride)
        for cy in range(blsize, n_rows - blsize, stride)
    )

    grid = {}
    for index, (cx, cy) in enumerate(centres):
        top, left = cy - blsize, cx - blsize
        bottom, right = top + 2 * blsize, left + 2 * blsize
        rect = [box[0] + left, box[1] + top, right - left, bottom - top]
        grid[str(index)] = {'rect': rect, 'size': rect[2]}
    return grid
def getGF_fromintegral(integralimage, rect):
    """Sum the source values inside ``rect`` using four lookups in an integral image.

    ``rect`` is [x, y, width, height]; the integral image is indexed as [row, column].
    The caller divides the result by the patch area to obtain a density.
    """
    left, top = rect[0], rect[1]
    right = rect[0] + rect[2]
    bottom = rect[1] + rect[3]
    return (
        integralimage[bottom, right]
        - integralimage[top, right]
        - integralimage[bottom, left]
        + integralimage[top, left]
    )
cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3]) # Check if patching is beneficial by comparing the gradient density of the patch to # the gradient density of the whole image if cgf >= gf: bbox_test = bbox.copy() patchlist[str(count)] = {} # Enlarge each patch until the gradient density of the patch is equal # to the whole image gradient density while True: bbox_test[0] = bbox_test[0] - int(search_step/2) bbox_test[1] = bbox_test[1] - int(search_step/2) bbox_test[2] = bbox_test[2] + search_step bbox_test[3] = bbox_test[3] + search_step # Check if we are still within the image if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \ or bbox_test[0] + bbox_test[2] >= width: break # Compare gradient density cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3]) if cgf < gf: break bbox = bbox_test.copy() # Add patch to selected patches patchlist[str(count)]['rect'] = bbox patchlist[str(count)]['size'] = bbox[2] count = count + 1 # Return selected patches return patchlist def impatch(image, rect): # Extract the given patch pixels from a given image. 
class ImageandPatchs:
    """An RGB image, its list of patches and (optionally) two depth estimates.

    Indexing returns the pixels of one patch, cut from the scaled RGB image and, once
    both estimates have been set, from the base and updated estimates as well.
    """

    def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
        self.root_dir = root_dir
        self.name = name
        # The same patch list is kept under both attribute names.
        self.patchsinfo = patchsinfo
        self.patchs = patchsinfo
        self.scale = scale

        scaled_size = (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale))
        self.rgb_image = cv2.resize(rgb_image, scaled_size, interpolation=cv2.INTER_CUBIC)

        self.do_have_estimate = False
        self.estimation_updated_image = None
        self.estimation_base_image = None

    def __len__(self):
        return len(self.patchs)

    def set_base_estimate(self, est):
        """Store the base estimate; flag estimates as available once both are set."""
        self.estimation_base_image = est
        if self.estimation_updated_image is not None:
            self.do_have_estimate = True

    def set_updated_estimate(self, est):
        """Store the updated estimate; flag estimates as available once both are set."""
        self.estimation_updated_image = est
        if self.estimation_base_image is not None:
            self.do_have_estimate = True

    def __getitem__(self, index):
        """Return a dict describing patch ``index`` (rgb crop, rect, size, id and,
        when available, the crops of both estimates)."""
        entry = self.patchs[index]
        patch_id = int(entry[0])
        info = entry[1]

        # Bring the stored rect and size to the working scale.
        rect = np.round(np.array(info['rect']) * self.scale).astype('int')
        msize = round(info['size'] * self.scale)

        sample = {'patch_rgb': impatch(self.rgb_image, rect)}
        if self.do_have_estimate:
            sample['patch_whole_estimate_base'] = impatch(self.estimation_base_image, rect)
            sample['patch_whole_estimate_updated'] = impatch(self.estimation_updated_image, rect)
        sample.update(rect=rect, size=msize, id=patch_id)
        return sample

    def print_options(self, opt):
        """Print every option in ``opt``, marking values that differ from the parser default.

        NOTE: reads ``self.parser``, which this class never assigns.
        """
        lines = ['----------------- Options ---------------\n']
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        message = ''.join(lines)
        print(message)
        # Saving the options to disk is disabled in this version.

    def parse(self):
        """Gather options, append the formatted suffix to the name and parse the gpu id list.

        NOTE: relies on ``self.gather_options`` and ``self.isTrain``, which this class
        never defines.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # A non-empty suffix template is formatted with the options themselves.
        if opt.suffix:
            suffix = '_' + opt.suffix.format(**vars(opt))
            opt.name = opt.name + suffix

        # Keep only the non-negative ids from the comma-separated string.
        tokens = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        opt.gpu_ids.extend(gpu for gpu in map(int, tokens) if gpu >= 0)

        self.opt = opt
        return self.opt
mask_org = generatemask((3000, 3000)) mask = mask_org.copy() # Value x of R_x defined in the section 5 of the main paper. r_threshold_value = 0.2 #if R0: # r_threshold_value = 0 input_resolution = img.shape scale_threshold = 3 # Allows up-scaling with a scale up to 3 # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the # supplementary material. whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold) # print('wholeImage being processed in :', whole_image_optimal_size) # Generate the base estimate using the double estimation. whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel) # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select # small high-density regions of the image. global factor factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2) # print('Adjust factor is:', 1/factor) # Check if Local boosting is beneficial. if max_res < whole_image_optimal_size: # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result") return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) # Compute the default target resolution. if img.shape[0] > img.shape[1]: a = 2 * whole_image_optimal_size b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0]) else: a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1]) b = 2 * whole_image_optimal_size b = int(round(b / factor)) a = int(round(a / factor)) """ # recompute a, b and saturate to max res. 
if max(a,b) > max_res: print('Default Res is higher than max-res: Reducing final resolution') if img.shape[0] > img.shape[1]: a = max_res b = round(max_res * img.shape[1] / img.shape[0]) else: a = round(max_res * img.shape[0] / img.shape[1]) b = max_res b = int(b) a = int(a) """ img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC) # Extract selected patches for local refinement base_size = net_receptive_field_size * 2 patchset = generatepatchs(img, base_size) # print('Target resolution: ', img.shape) # Computing a scale in case user prompted to generate the results as the same resolution of the input. # Notice that our method output resolution is independent of the input resolution and this parameter will only # enable a scaling operation during the local patch merge implementation to generate results with the same resolution # as the input. """ if output_resolution == 1: mergein_scale = input_resolution[0] / img.shape[0] print('Dynamicly change merged-in resolution; scale:', mergein_scale) else: mergein_scale = 1 """ # always rescale to input res for now mergein_scale = input_resolution[0] / img.shape[0] imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale) whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale), round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC) imageandpatchs.set_base_estimate(whole_estimate_resized.copy()) imageandpatchs.set_updated_estimate(whole_estimate_resized.copy()) print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2]) print('Patches to process: '+str(len(imageandpatchs))) # Enumerate through all patches, generate their estimations and refining the base estimate. 
for patch_ind in range(len(imageandpatchs)): # Get patch information patch = imageandpatchs[patch_ind] # patch object patch_rgb = patch['patch_rgb'] # rgb patch patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base rect = patch['rect'] # patch size and location patch['id'] # patch ID org_size = patch_whole_estimate_base.shape # the original size from the unscaled input print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect) # We apply double estimation for patches. The high resolution value is fixed to twice the receptive # field size of the network for patches to accelerate the process. patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel) patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) # Merging the patch estimation into the base estimate using our merge network: # We feed the patch estimation and the same region from the updated base estimate to the merge network # to generate the target estimate for the corresponding region. 
pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation) # Run merging network pix2pixmodel.test() visuals = pix2pixmodel.get_current_visuals() prediction_mapped = visuals['fake_B'] prediction_mapped = (prediction_mapped+1)/2 prediction_mapped = prediction_mapped.squeeze().cpu().numpy() mapped = prediction_mapped # We use a simple linear polynomial to make sure the result of the merge network would match the values of # base estimate p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1) merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape) merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC) # Get patch size and location w1 = rect[0] h1 = rect[1] w2 = w1 + rect[2] h2 = h1 + rect[3] # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size # and resize it to our needed size while merging the patches. if mask.shape != org_size: mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR) tobemergedto = imageandpatchs.estimation_updated_image # Update the whole estimation: # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless # blending at the boundaries of the patch region. tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask) imageandpatchs.set_updated_estimate(tobemergedto) # output return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) ================================================ FILE: modules/control/proc/leres/leres/multi_depth_model_woauxi.py ================================================ import torch import torch.nn as nn from . 
def get_func(func_name):
    """Return a function object looked up by name.

    Parameters:
        func_name (str) -- either a bare name defined in this module, or a dotted
                           path relative to the 'modules.control.proc.leres.leres' package.

    Returns:
        The resolved object, or None when `func_name` is the empty string.

    Raises:
        Whatever the lookup raised (KeyError, ImportError, AttributeError, ...);
        the failure is reported on stdout and then re-raised unchanged.
    """
    if func_name == '':
        return None
    try:
        parts = func_name.split('.')
        # Refers to a function in this module
        if len(parts) == 1:
            return globals()[parts[0]]
        # Otherwise, assume we're referencing a module under the leres package
        module_name = 'modules.control.proc.leres.leres.' + '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception:
        # print() does not interpolate: format the message explicitly.
        print('Failed to find function: %s' % func_name)
        raise


def load_ckpt(args, depth_model, shift_model, focal_model):
    """Load model weights from the checkpoint file named by `args.load_ckpt`.

    Parameters:
        args         -- object exposing a `load_ckpt` path attribute
        depth_model  -- module receiving checkpoint['depth_model'] (always loaded)
        shift_model  -- module receiving checkpoint['shift_model'], or None to skip
        focal_model  -- module receiving checkpoint['focal_model'], or None to skip

    Nothing happens when the path is not an existing file. A leading 'module.'
    prefix is stripped from the keys when every key carries it.
    """
    if os.path.isfile(args.load_ckpt):
        print("loading checkpoint %s" % args.load_ckpt)
        # Deserialize onto the CPU so CUDA-saved checkpoints also load on CPU-only hosts;
        # load_state_dict copies the values into each model's existing parameters.
        checkpoint = torch.load(args.load_ckpt, map_location='cpu')
        if shift_model is not None:
            shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
                                        strict=True)
        if focal_model is not None:
            focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
                                        strict=True)
        depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
                                    strict=True)
        del checkpoint
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def strip_prefix_if_present(state_dict, prefix):
    """Return `state_dict` with a leading `prefix` removed from every key.

    The mapping is returned untouched (same object) unless *all* keys start with
    `prefix`. Only the leading occurrence is removed, so a key such as
    'module.block.module.weight' becomes 'block.module.weight'.

    Returns:
        The original mapping, or a new OrderedDict preserving the original key order.
    """
    if not all(key.startswith(prefix) for key in state_dict):
        return state_dict
    cut = len(prefix)
    stripped_state_dict = OrderedDict()
    for key, value in state_dict.items():
        stripped_state_dict[key[cut:]] = value
    return stripped_state_dict
class Decoder(nn.Module):
    """Depth decoder: fuses four encoder feature maps, deepest first, into a single-channel map."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.inchannels = [256, 512, 1024, 2048]
        self.midchannels = [256, 256, 256, 512]
        self.upfactors = [2, 2, 2, 2]
        self.outchannels = 1

        in_ch, mid_ch, up = self.inchannels, self.midchannels, self.upfactors

        # Head applied to the deepest feature map.
        self.conv = FTB(inchannels=in_ch[3], midchannels=mid_ch[3])
        self.conv1 = nn.Conv2d(in_channels=mid_ch[3], out_channels=mid_ch[2],
                               kernel_size=3, padding=1, stride=1, bias=True)
        self.upsample = nn.Upsample(scale_factor=up[3], mode='bilinear', align_corners=True)

        # One fusion stage per remaining encoder level.
        self.ffm2 = FFM(inchannels=in_ch[2], midchannels=mid_ch[2], outchannels=mid_ch[2], upfactor=up[2])
        self.ffm1 = FFM(inchannels=in_ch[1], midchannels=mid_ch[1], outchannels=mid_ch[1], upfactor=up[1])
        self.ffm0 = FFM(inchannels=in_ch[0], midchannels=mid_ch[0], outchannels=mid_ch[0], upfactor=up[0])

        self.outconv = AO(inchannels=mid_ch[0], outchannels=self.outchannels, upfactor=2)

        self._init_params()

    def _init_params(self):
        """Re-initialise weights: N(0, 0.01) for conv/linear layers, identity for batch norm."""
        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
        for module in self.modules():
            if isinstance(module, weighted_layers):
                init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)

    def forward(self, features):
        """Decode `features` (indexable, entries 0..3 used; index 3 is processed first)."""
        out = self.upsample(self.conv1(self.conv(features[3])))
        stages = ((self.ffm2, features[2]), (self.ffm1, features[1]), (self.ffm0, features[0]))
        for fuse, skip in stages:
            out = fuse(skip, out)
        return self.outconv(out)
class ATA(nn.Module):
    """Channel gate: scales `low_x` per channel by weights computed from both inputs, then adds `high_x`."""

    def __init__(self, inchannels, reduction=8):
        super(ATA, self).__init__()
        self.inchannels = inchannels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeezed = self.inchannels // reduction
        self.fc = nn.Sequential(
            nn.Linear(self.inchannels * 2, squeezed),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed, self.inchannels),
            nn.Sigmoid(),
        )
        self.init_params()

    def forward(self, low_x, high_x):
        """Return `low_x * gate + high_x`, where gate has shape (n, c, 1, 1)."""
        batch, channels = low_x.size()[:2]
        pooled = self.avg_pool(torch.cat([low_x, high_x], 1)).view(batch, -1)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return low_x * gate + high_x

    def init_params(self):
        """Xavier-normal for (transposed) convolutions, N(0, 0.01) for linear layers, identity batch norm."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    init.constant_(layer.bias, 0)
ATA(inchannels = self.midchannels) self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) self.init_params() def forward(self, low_x, high_x): x = self.ftb1(low_x) x = x + high_x x = self.ftb2(x) x = self.upsample(x) return x def init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d): # init.kaiming_normal_(m.weight, mode='fan_out') init.normal_(m.weight, std=0.01) # init.xavier_normal_(m.weight) if m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, nn.ConvTranspose2d): # init.kaiming_normal_(m.weight, mode='fan_out') init.normal_(m.weight, std=0.01) # init.xavier_normal_(m.weight) if m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d init.constant_(m.weight, 1) init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): init.normal_(m.weight, std=0.01) if m.bias is not None: init.constant_(m.bias, 0) class AO(nn.Module): # Adaptive output module def __init__(self, inchannels, outchannels, upfactor=2): super(AO, self).__init__() self.inchannels = inchannels self.outchannels = outchannels self.upfactor = upfactor self.adapt_conv = nn.Sequential( nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1, stride=1, bias=True), \ nn.BatchNorm2d(num_features=self.inchannels // 2), \ nn.ReLU(inplace=True), \ nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True), \ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) self.init_params() def forward(self, x): x = self.adapt_conv(x) return x def init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d): # init.kaiming_normal_(m.weight, mode='fan_out') init.normal_(m.weight, std=0.01) # init.xavier_normal_(m.weight) if m.bias is not None: init.constant_(m.bias, 0) elif isinstance(m, 
class ResidualConv(nn.Module):
    """Residual bottleneck: ReLU -> 3x3 conv (C -> C//2) -> BN -> ReLU -> 3x3 conv (C//2 -> C), plus identity.

    Parameters:
        inchannels (int) -- number of input (and output) channels.
    """

    def __init__(self, inchannels):
        super(ResidualConv, self).__init__()
        # Channel counts must be ints: true division yields a float, which
        # nn.Conv2d / nn.BatchNorm2d reject, so use floor division.
        midchannels = inchannels // 2
        self.conv = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=inchannels, out_channels=midchannels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(num_features=midchannels),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=midchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
        )
        self.init_params()

    def forward(self, x):
        """Return `conv(x) + x`; the output has the same shape as the input."""
        x = self.conv(x) + x
        return x

    def init_params(self):
        """Initialise conv/linear weights from N(0, 0.01), biases to 0 and batch norm to identity."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
class SenceUnderstand(nn.Module):
    """Global context branch: pools a 512-channel map into a `channels`-vector and tiles it back to (h, w)."""

    def __init__(self, channels):
        super(SenceUnderstand, self).__init__()
        self.channels = channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.fc = nn.Sequential(
            nn.Linear(512 * 8 * 8, self.channels),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
        )
        self.initial_params()

    def forward(self, x):
        """Map an (n, 512, h, w) tensor to (n, channels, h, w); every spatial position holds the same vector."""
        batch, _, height, width = x.size()
        pooled = self.pool(self.conv1(x)).view(batch, -1)
        context = self.fc(pooled).view(batch, self.channels, 1, 1)
        return self.conv2(context).repeat(1, 1, height, width)

    def initial_params(self, dev=0.01):
        """Draw weights from N(0, dev); conv biases are zeroed, linear biases are left as constructed."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                layer.weight.data.normal_(0, dev)
                if layer.bias is not None:
                    layer.bias.data.fill_(0)
            elif isinstance(layer, nn.Linear):
                layer.weight.data.normal_(0, dev)
================================================ FILE: modules/control/proc/leres/pix2pix/__init__.py ================================================ ================================================ FILE: modules/control/proc/leres/pix2pix/models/__init__.py ================================================ """This package contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. You need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- <set_input>: unpack data from dataset and apply preprocessing. -- <forward>: produce intermediate results. -- <optimize_parameters>: calculate loss, gradients, and update network weights. -- <modify_commandline_options>: (optionally) add model-specific options and set default options. In the function <__init__>, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): define networks used in our training. -- self.visual_names (str list): specify the images that you want to display and save. -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. Now you can use the model class by specifying flag '--model dummy'. See our template model class 'template_model.py' for more details. """ import importlib from .base_model import BaseModel def find_model_using_name(model_name): """Import the module "models/[model_name]_model.py". In the file, the class named [model_name]Model (underscores removed) is looked up and returned. It has to be a subclass of BaseModel, and the name match is case-insensitive. """ model_filename = "modules.control.proc.leres.pix2pix.models."
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call BaseModel.__init__(self, opt).
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network.
               If two networks are updated at the same time, you can use itertools.chain to group them.
               See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both optimize_parameters() and test()."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        It also calls compute_visuals() to produce additional visualization results
        """
        self.forward()
        self.compute_visuals()

    def compute_visuals(self): # noqa
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # Save the wrapped module when the net is a DataParallel-style wrapper,
                    # otherwise the net itself (plain modules have no `.module`).
                    wrapped = getattr(net, 'module', net)
                    torch.save(wrapped.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def unload_network(self, name):
        """Drop the reference held in `self.net<name>` and run garbage collection.

        Parameters:
            name (str) -- network suffix, as listed in self.model_names

        The attribute is reset to None so that the network can actually be
        collected; deleting only a local alias would leave it alive.
        Raises AttributeError when no such attribute exists. Always returns None.
        """
        if isinstance(name, str):
            attr_name = 'net' + name
            net = getattr(self, attr_name)
            setattr(self, attr_name, None)
            del net
            gc.collect()
            torch_gc()
        return None

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                # print('Loading depth boost model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations

        Parameters:
            nets (network list)   -- a list of networks (a single network is also accepted)
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
class BaseModelHG():
    """Minimal model base class: stores options and offers checkpoint save/load helpers.

    The training/inference hooks (``forward``, ``test``, ``optimize_parameters``, ...)
    are no-ops here and are meant to be overridden by subclasses.
    """

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, opt):
        """Copy the relevant settings from ``opt`` onto the instance.

        Reads ``opt.gpu_ids``, ``opt.isTrain``, ``opt.checkpoints_dir`` and ``opt.name``.
        Checkpoints live in ``<checkpoints_dir>/<name>``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # CUDA tensor type only when at least one GPU id was configured
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        """Store the input that subsequent calls operate on."""
        self.input = input

    def forward(self):
        """Forward pass hook; no-op in the base class."""
        pass

    # used in test time, no backprop
    def test(self):
        """Inference hook; no-op in the base class."""
        pass

    def get_image_paths(self):
        """Image path hook; returns None in the base class."""
        pass

    def optimize_parameters(self):
        """Training step hook; no-op in the base class."""
        pass

    def get_current_visuals(self):
        """Return the stored input unchanged."""
        return self.input

    def get_current_errors(self):
        """Return the current losses; always an empty dict in the base class."""
        return {}

    def save(self, label):
        """Checkpoint hook; no-op in the base class."""
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s state dict to ``self.save_dir``.

        The network is moved to CPU for saving and, when ``gpu_ids`` is non-empty
        and CUDA is available, moved back to ``gpu_ids[0]`` afterwards.
        """
        # NOTE(review): the leading underscore here does not match the name built in
        # load_network ('%s_net_%s.pth'); left untouched because it decides on-disk file names.
        save_filename = '_%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            # Module.cuda() takes `device`, not `device_id`; the old keyword raised TypeError
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load and return the checkpoint ``'%s_net_%s.pth' % (epoch_label, network_label)``.

        ``network`` is not modified; the caller receives whatever ``torch.load``
        yields and is responsible for applying it.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print(save_path)
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        model = torch.load(save_path)
        return model
        # network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """Learning-rate update hook; no-op in the base class."""
        pass
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not one of the supported names
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # multiplier is 1.0 until epoch + epoch_count exceeds n_epochs, then falls linearly
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # previously the exception object was *returned* (and its message never formatted),
        # so callers silently received an exception instance in place of a scheduler
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Place a network on its device(s) and initialize its weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2; None means no GPU

    When gpu_ids is non-empty the network is moved to the first id and wrapped in
    torch.nn.DataParallel before initialization.

    Return an initialized network.
    """
    devices = [] if gpu_ids is None else gpu_ids
    if len(devices) > 0:
        assert torch.cuda.is_available()
        net.to(devices[0])
        net = torch.nn.DataParallel(net, devices)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=None):
    """Create a discriminator.

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator that has been passed through init_net.

    Supported architectures:
        [basic]: NLayerDiscriminator with n_layers fixed to 3 (the PatchGAN of the pix2pix paper).
        [n_layers]: NLayerDiscriminator with n_layers taken from n_layers_D.
        [pixel]: PixelDiscriminator, a 1x1 PixelGAN that classifies each pixel.

    Raises:
        NotImplementedError -- if netD is not one of the names above
    """
    device_ids = [] if gpu_ids is None else gpu_ids
    make_norm = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        discriminator = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=make_norm)
    elif netD == 'n_layers':  # caller-chosen depth
        discriminator = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=make_norm)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        discriminator = PixelDiscriminator(input_nc, ndf, norm_layer=make_norm)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)

    return init_net(discriminator, init_type, init_gain, device_ids)
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)     -- discriminator network
        real_data (tensor) -- real images
        fake_data (tensor) -- generated images from the generator
        device (str)       -- device on which the random mixing weights and grad seeds are created
        type (str)         -- which samples are scored: real | fake | mixed (random per-sample blend)
        constant (float)   -- the constant used in formula ( ||gradient||_2 - constant)^2
        lambda_gp (float)  -- weight for this loss; a value that is not > 0 disables the penalty

    Returns (gradient_penalty, gradients), or (0.0, None) when the penalty is disabled.
    Note that requires_grad_ is switched on in place on the scored tensor.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        # one random weight per sample, broadcast over all of that sample's elements
        mix = torch.rand(batch, 1, device=device)
        mix = mix.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
        samples = mix * real_data + ((1 - mix) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    scores = netD(samples)
    grad_seed = torch.ones(scores.size()).to(device)
    grads = torch.autograd.grad(outputs=scores, inputs=samples, grad_outputs=grad_seed,
                                create_graph=True, retain_graph=True, only_inputs=True)
    flat = grads[0].view(real_data.size(0), -1)  # one row of gradients per sample
    per_sample_norm = (flat + 1e-16).norm(2, dim=1)  # eps added before taking the norm
    penalty = ((per_sample_norm - constant) ** 2).mean() * lambda_gp
    return penalty, flat
class ResnetBlock(nn.Module):
    """Residual block: a two-convolution stack whose output is added back onto its input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Build the block.

        The convolutional stack comes from build_conv_block; the skip connection is
        applied in forward. Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the convolutional stack.

        Parameters:
            dim (int)          -- the number of channels in the conv layers.
            padding_type (str) -- the name of padding layer: reflect | replicate | zero
            norm_layer         -- callable returning a normalization layer for a channel count
            use_dropout (bool) -- if a Dropout(0.5) is placed between the two convolutions.
            use_bias (bool)    -- if the conv layers use bias or not

        Returns nn.Sequential of: [pad,] conv, norm, ReLU, [dropout,] [pad,] conv, norm.

        Raises:
            NotImplementedError -- for an unknown padding_type
        """
        def padded_conv():
            # explicit padding layer for reflect/replicate; 'zero' pads inside the conv itself
            if padding_type == 'reflect':
                prefix, implicit = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                prefix, implicit = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                prefix, implicit = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return prefix + [nn.Conv2d(dim, dim, kernel_size=3, padding=implicit, bias=use_bias)]

        layers = padded_conv()
        layers += [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv()
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the conv stack's output (skip connection)."""
        return x + self.conv_block(x)
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features; defaults to outer_nc
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers (only applied to intermediate blocks).
        """
        super().__init__()
        self.outermost = outermost

        # convolutions carry a bias only when the normalization is InstanceNorm2d
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        in_channels = outer_nc if input_nc is None else input_nc

        conv_down = nn.Conv2d(in_channels, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)
        act_down = nn.LeakyReLU(0.2, True)
        norm_down = norm_layer(inner_nc)
        act_up = nn.ReLU(True)
        norm_up = norm_layer(outer_nc)

        if outermost:
            # no normalization at the outer edge; the output is squashed with Tanh
            conv_up = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2, padding=1)
            layers = [conv_down, submodule, act_up, conv_up, nn.Tanh()]
        elif innermost:
            # bottleneck: nothing nested, so the up-convolution sees inner_nc channels only
            conv_up = nn.ConvTranspose2d(inner_nc, outer_nc,
                                         kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [act_down, conv_down, act_up, conv_up, norm_up]
        else:
            conv_up = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [act_down, conv_down, norm_down, submodule, act_up, conv_up, norm_up]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block; all but the outermost block concatenate the input onto the output."""
        result = self.model(x)
        if self.outermost:
            return result
        return torch.cat([x, result], 1)  # add skip connections
class Pix2Pix4DepthModel(BaseModel):
    """pix2pix model that maps a two-channel (outer, inner) input to a one-channel output.

    The generator is always built as 'unet_1024' with norm 'none'; in training mode a
    discriminator, a GAN loss, an L1 loss and two Adam optimizers are created as well.

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Rewrite option defaults for this model and add its training-only options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase.

        Returns:
            the modified parser.

        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        """
        parser.set_defaults(input_nc=2, output_nc=1, norm='none', netG='unet_1024', dataset_mode='depthmerge')
        if not is_train:
            return parser
        parser.set_defaults(pool_size=0, gan_mode='vanilla')
        parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss')
        return parser

    def __init__(self, opt):
        """Initialize the model.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # names of the losses reported during training (read back as 'loss_' + name)
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        training = self.isTrain
        # images to save/display, and the networks to save/load ('net' + name)
        self.visual_names = ['outer', 'inner', 'fake_B', 'real_B'] if training else ['fake_B']
        self.model_names = ['G', 'D'] if training else ['G']  # during test time, only load G

        # generator architecture is fixed here regardless of opt.netG / opt.norm
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
                                      False, 'normal', 0.02, self.gpu_ids)

        if training:
            # conditional GAN: the discriminator sees input and output channels together
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            # loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # optimizers use fixed learning rates; only beta1 comes from the options
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=adam_betas)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=adam_betas)
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input_train(self, input):
        """Unpack a dataset batch, resizing every tensor to 1024x1024 with bilinear interpolation."""
        resize = torch.nn.functional.interpolate
        target_size = (1024, 1024)

        self.outer = input['data_outer'].to(self.device)
        self.outer = resize(self.outer, target_size, mode='bilinear', align_corners=False)

        self.inner = input['data_inner'].to(self.device)
        self.inner = resize(self.inner, target_size, mode='bilinear', align_corners=False)

        self.image_paths = input['image_path']

        if self.isTrain:
            self.gtfake = input['data_gtfake'].to(self.device)
            self.gtfake = resize(self.gtfake, target_size, mode='bilinear', align_corners=False)
            self.real_B = self.gtfake

        self.real_A = torch.cat((self.outer, self.inner), 1)

    def set_input(self, outer, inner):
        """Build the generator input from two numpy arrays.

        Each array gets two leading singleton dimensions, is min-max scaled, and is then
        passed through normalize; the results are concatenated (outer first) along dim 1.
        """
        def prepare(array):
            tensor = torch.from_numpy(array).unsqueeze(0).unsqueeze(0)
            return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor))

        inner_scaled = prepare(inner)
        outer_scaled = prepare(outer)
        inner_scaled = self.normalize(inner_scaled)
        outer_scaled = self.normalize(outer_scaled)
        self.real_A = torch.cat((outer_scaled, inner_scaled), 1).to(self.device)

    def normalize(self, input):
        """Return ``input * 2 - 1``."""
        doubled = input * 2
        return doubled - 1

    def forward(self):
        """Run forward pass: fake_B = G(real_A)."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake pair; detach so no gradient flows back into the generator
        fake_pair = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_D_fake = self.criterionGAN(self.netD(fake_pair.detach()), False)
        # Real pair
        real_pair = torch.cat((self.real_A, self.real_B), 1)
        self.loss_D_real = self.criterionGAN(self.netD(real_pair), True)
        # average both terms and backpropagate
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fool the discriminator
        fake_pair = torch.cat((self.real_A, self.fake_B), 1)
        self.loss_G_GAN = self.criterionGAN(self.netD(fake_pair), True)
        # Second, G(A) should match B, weighted by lambda_L1
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # sum both terms and backpropagate
        self.loss_G = self.loss_G_L1 + self.loss_G_GAN
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: forward pass, then update D, then update G."""
        self.forward()                            # compute fake images: G(A)
        # --- discriminator ---
        self.set_requires_grad(self.netD, True)   # enable backprop for D
        self.optimizer_D.zero_grad()              # clear D's gradients
        self.backward_D()                         # calculate gradients for D
        self.optimizer_D.step()                   # update D's weights
        # --- generator ---
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()              # clear G's gradients
        self.backward_G()                         # calculate gradients for G
        self.optimizer_G.step()                   # update G's weights
It also gathers additional options defined in functions in both dataset class and model class. """ def __init__(self): """Reset the class; indicates the class hasn't been initailized""" self.initialized = False def initialize(self, parser): """Define the common options that are used in both training and test.""" # basic parameters parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet') parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here') # model parameters parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale') parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') # dataset parameters parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') parser.add_argument('--batch_size', type=int, default=1, help='input batch size') parser.add_argument('--load_size', type=int, default=672, help='scale images to this size') parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size') parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') # additional parameters parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') parser.add_argument('--data_dir', type=str, required=False, help='input files directory images can be .png .jpg .tiff') parser.add_argument('--output_dir', type=str, required=False, help='result dir. result depth will be png. 
def gather_options(self):
    """Build the option parser (only once) and return the parsed options.

    On the first call a parser is created, filled with the basic options by
    ``self.initialize`` and then extended with the model-specific options
    returned by ``models.get_option_setter`` for the selected ``--model``.
    The finished parser is cached on ``self.parser``.

    Later calls reuse the cached parser. Previously the local ``parser`` was
    only bound while ``self.initialized`` was false, so a second call raised
    ``UnboundLocalError``.

    Returns:
        argparse.Namespace -- the parsed options
    """
    parser = getattr(self, 'parser', None)
    if not self.initialized or parser is None:
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # First pass: only needed to learn which model was requested.
        opt, _ = parser.parse_known_args()

        # Let the selected model add its options / change the defaults.
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        self.parser = parser

    # parse_known_args (not parse_args) so that command-line arguments the
    # parser does not define are ignored instead of aborting the process.
    opt, _ = parser.parse_known_args()
    return opt
def parse(self):
    """Parse the options, apply the name suffix and decode the GPU id list.

    The options come from ``self.gather_options()``. ``opt.isTrain`` is
    copied from ``self.isTrain``. When ``opt.suffix`` is non-empty it is
    formatted with the option values and appended to ``opt.name`` after an
    underscore. ``opt.gpu_ids`` is converted from a comma-separated string
    to a list of the non-negative integer ids it contains.

    Blank entries in the id string (for example a trailing comma) are
    skipped instead of raising ``ValueError``.

    Returns:
        argparse.Namespace -- the processed options, also stored on ``self.opt``
    """
    opt = self.gather_options()
    opt.isTrain = self.isTrain  # train or test

    # process opt.suffix
    if opt.suffix:
        opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

    # set gpu ids; negative ids are dropped, leaving an empty list for "-1"
    gpu_ids = []
    for token in opt.gpu_ids.split(','):
        token = token.strip()
        if not token:
            continue
        gpu_id = int(token)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)
    opt.gpu_ids = gpu_ids

    self.opt = opt
    return self.opt
def tensor2im(input_image, imtype=np.uint16):
    """Convert an image tensor into a numpy array of the requested dtype.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array

    A numpy array is only cast to ``imtype``. A torch tensor is squeezed,
    moved to the CPU, mapped with ``(x + 1) / 2 * (2**16 - 1)`` and then
    cast. Any other object is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # already a numpy array: no rescaling, only the dtype cast
        return input_image.astype(imtype)

    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_numpy = torch.squeeze(input_image.data).cpu().numpy()
    scaled = (image_numpy + 1) / 2.0 * (2**16 - 1)
    return scaled.astype(imtype)
def mkdir(path):
    """Create a directory, including missing parents, if it does not exist.

    Parameters:
        path (str) -- a single directory path

    ``exist_ok=True`` replaces the earlier exists-then-create sequence, which
    could fail when something else created the directory between the check
    and the ``os.makedirs`` call.

    Raises:
        FileExistsError -- if ``path`` exists but is not a directory
    """
    os.makedirs(path, exist_ok=True)
class LineartDetector:
    """Line-art extractor holding a regular and a coarse ``Generator``."""

    def __init__(self, model, coarse_model):
        self.model = model
        self.model_coarse = coarse_model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, coarse_filename=None, cache_dir=None, local_files_only=False):
        """Load both generators from a local directory or a Hugging Face repo.

        ``filename`` defaults to ``sk_model.pth`` and ``coarse_filename`` to
        ``sk_model2.pth``. When ``pretrained_model_or_path`` is an existing
        directory the files are read from it, otherwise they are fetched with
        ``hf_hub_download``. Weights are loaded onto the CPU and both networks
        are put into eval mode.
        """
        filename = filename or "sk_model.pth"
        coarse_filename = coarse_filename or "sk_model2.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
            coarse_model_path = os.path.join(pretrained_model_or_path, coarse_filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
            coarse_model_path = hf_hub_download(pretrained_model_or_path, coarse_filename, cache_dir=cache_dir, local_files_only=local_files_only)

        model = Generator(3, 1, 3)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()

        coarse_model = Generator(3, 1, 3)
        coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu')))
        coarse_model.eval()

        return cls(model, coarse_model)

    def to(self, device):
        """Move both generators to ``device`` and return ``self``."""
        self.model.to(device)
        self.model_coarse.to(device)
        return self

    def __call__(self, input_image, coarse=False, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Extract a line-art map from ``input_image``.

        The image is converted with ``HWC3``, resized to ``detect_resolution``,
        scaled to [0, 1] and run through the coarse or the regular generator.
        The output is converted to uint8, resized to the size of the input
        at ``image_resolution`` and inverted (``255 - map``).

        Returns a ``PIL.Image`` when ``output_type == "pil"``, otherwise the
        uint8 numpy array.
        """
        # Select the network first so that the one that actually runs is the
        # one moved to the compute device (and offloaded afterwards). Before,
        # only ``self.model`` was moved, so ``coarse=True`` could run the
        # coarse generator on a different device than its input.
        model = self.model_coarse if coarse else self.model
        model.to(devices.device)
        device = next(iter(model.parameters())).device

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)
        assert input_image.ndim == 3

        image = torch.from_numpy(input_image).float().to(device)
        image = image / 255.0
        image = rearrange(image, 'h w c -> 1 c h w')

        # first batch item, first (only) output channel
        line = model(image)[0][0]
        line = line.cpu().numpy()
        line = (line * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = HWC3(line)

        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        detected_map = 255 - detected_map

        if opts.control_move_processor:
            model.to('cpu')

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 # at the bottleneck ngf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer We construct the U-Net from the innermost layer to the outermost layer. It is a recursive process. """ super(UnetGenerator, self).__init__() # construct unet structure unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) # gradually reduce the number of filters from ngf * 8 to ngf unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer def forward(self, input): # pylint: disable=redefined-builtin """Standard forward""" return self.model(input) class UnetSkipConnectionBlock(nn.Module): """Defines the Unet submodule with skip connection. X -------------------identity---------------------- |-- downsampling -- |submodule| -- upsampling --| """ def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): """Construct a Unet submodule with skip connections. 
class LineartAnimeDetector:
    """Anime line-art extractor built on a ``UnetGenerator``."""

    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        """Load the generator from a local directory or a Hugging Face repo.

        ``filename`` defaults to ``netG.pth``. When
        ``pretrained_model_or_path`` is an existing directory the file is read
        from it, otherwise it is fetched with ``hf_hub_download``.
        """
        filename = filename or "netG.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)

        # Load onto the CPU regardless of the device the checkpoint was saved
        # from; __call__ moves the network to the compute device later.
        ckpt = torch.load(model_path, map_location=torch.device('cpu'))
        # Drop every 'module.' fragment from the parameter names.
        ckpt = {key.replace('module.', ''): value for key, value in ckpt.items()}
        net.load_state_dict(ckpt)
        net.eval()
        return cls(net)

    def to(self, device):
        """Move the generator to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Extract a line-art map from ``input_image``.

        The image is converted with ``HWC3``, resized to ``detect_resolution``
        and then stretched so both sides are multiples of 256 before being fed
        to the network in the [-1, 1] range. The prediction is mapped back to
        [0, 255], resized to the size of the input at ``image_resolution``
        and inverted (``255 - map``).

        Returns a ``PIL.Image`` when ``output_type == "pil"``, otherwise the
        uint8 numpy array.
        """
        self.model.to(devices.device)
        device = next(iter(self.model.parameters())).device

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        H, W, _C = input_image.shape
        # round each side up to the next multiple of 256
        Hn = 256 * int(np.ceil(float(H) / 256.0))
        Wn = 256 * int(np.ceil(float(W) / 256.0))
        img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)

        image_feed = torch.from_numpy(img).float().to(device)
        image_feed = image_feed / 127.5 - 1.0
        image_feed = rearrange(image_feed, 'h w c -> 1 c h w')

        line = self.model(image_feed)[0, 0] * 127.5 + 127.5
        line = line.cpu().numpy()

        # undo the multiple-of-256 stretch
        line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
        line = line.clip(0, 255).astype(np.uint8)

        detected_map = HWC3(line)

        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        detected_map = 255 - detected_map

        if opts.control_move_processor:
            self.model.to('cpu')

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
class MarigoldDetector:
    """Depth processor that delegates to a ``MarigoldPipeline``."""

    def __init__(self, model):
        self.model: MarigoldPipeline = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, cache_dir=None, **load_config):
        """Create the detector from ``MarigoldPipeline.from_pretrained``."""
        return cls(MarigoldPipeline.from_pretrained(pretrained_model_or_path, cache_dir=cache_dir, **load_config))

    def to(self, device):
        """Move the pipeline to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def __call__(
        self,
        input_image: Image,
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        color_map: str = "Spectral",
        output_type=None,
    ):
        """Run the pipeline on ``input_image`` and return its depth output.

        The pipeline is moved to ``devices.device`` in float16 before the run
        and back to the CPU afterwards when ``opts.control_move_processor`` is
        set. With ``color_map == 'None'`` the raw ``depth_np`` result is used
        (the pipeline itself is still called with the 'Spectral' colormap);
        otherwise the ``depth_colored`` result is used. The selected result
        is wrapped with ``Image.fromarray`` when ``output_type == "pil"``.
        """
        colorize = color_map != 'None'

        self.model.to(device=devices.device, dtype=torch.float16)
        pipeline_args = {
            'denoising_steps': denoising_steps,
            'ensemble_size': ensemble_size,
            'processing_res': processing_res,
            'match_input_res': match_input_res,
            'color_map': color_map if colorize else 'Spectral',
            'batch_size': 1,
            'show_progress_bar': True,
        }
        res = self.model(input_image, **pipeline_args)

        if colorize:
            depth_map = res.depth_colored
        else:
            depth_map = res.depth_np

        if opts.control_move_processor:
            self.model.to('cpu')

        if output_type != "pil":
            return depth_map
        return Image.fromarray(depth_map)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # If you find this code useful, we kindly ask you to cite our paper in your work. # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation # More information about the method can be found at https://marigoldmonodepth.github.io # -------------------------------------------------------------------------- from typing import Dict, Union import torch from torch.utils.data import DataLoader, TensorDataset import numpy as np from tqdm.auto import tqdm from PIL import Image from diffusers import ( DiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, ) from diffusers.utils import BaseOutput from transformers import CLIPTextModel, CLIPTokenizer from .util.image_util import chw2hwc, colorize_depth_maps, resize_max_res from .util.batchsize import find_batch_size from .util.ensemble import ensemble_depths class MarigoldDepthOutput(BaseOutput): """ Output class for Marigold monocular depth prediction pipeline. Args: depth_np (`np.ndarray`): Predicted depth map, with depth values in the range of [0, 1]. depth_colored (`PIL.Image.Image`): Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. uncertainty (`None` or `np.ndarray`): Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. """ depth_np: np.ndarray depth_colored: Image.Image uncertainty: Union[None, np.ndarray] class MarigoldPipeline(DiffusionPipeline): """ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io. 
@torch.no_grad()
def __call__(
    self,
    input_image: Image,
    denoising_steps: int = 10,
    ensemble_size: int = 10,
    processing_res: int = 768,
    match_input_res: bool = True,
    batch_size: int = 0,
    color_map: str = "Spectral",
    show_progress_bar: bool = True,
    ensemble_kwargs: Dict = None,
) -> MarigoldDepthOutput:
    """
    Function invoked when calling the pipeline.

    Args:
        input_image (`Image`):
            Input RGB (or gray-scale) image.
        denoising_steps (`int`, *optional*, defaults to `10`):
            Number of diffusion denoising steps (DDIM) during inference.
        ensemble_size (`int`, *optional*, defaults to `10`):
            Number of predictions to be ensembled.
        processing_res (`int`, *optional*, defaults to `768`):
            Maximum resolution of processing.
            If set to 0: will not resize at all.
        match_input_res (`bool`, *optional*, defaults to `True`):
            Resize depth prediction to match input resolution.
        batch_size (`int`, *optional*, defaults to `0`):
            Inference batch size.
            If set to 0, the batch size is chosen by `find_batch_size`.
        color_map (`str`, *optional*, defaults to `"Spectral"`):
            Colormap used to colorize the depth map.
        show_progress_bar (`bool`, *optional*, defaults to `True`):
            Display a progress bar of diffusion denoising.
        ensemble_kwargs (`dict`, *optional*, defaults to `None`):
            Extra keyword arguments forwarded to `ensemble_depths`.
    Returns:
        `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
        - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
        - **depth_colored** (`PIL.Image.Image`) Colorized depth map image built from a uint8 [H, W, 3] array
        - **uncertainty** Value returned by the ensembling step, `None` if `ensemble_size = 1`
    """
    device = self.device
    input_size = input_image.size

    if not match_input_res:
        assert (
            processing_res is not None
        ), "Value error: `processing_res` must not be None when `match_input_res` is False"
    assert processing_res >= 0
    assert denoising_steps >= 1
    assert ensemble_size >= 1

    # ----------------- Image Preprocess -----------------
    # Resize image
    if processing_res > 0:
        input_image = resize_max_res(
            input_image, max_edge_resolution=processing_res
        )
    # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
    input_image = input_image.convert("RGB")
    image = np.asarray(input_image)

    # Normalize rgb values
    rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
    rgb_norm = rgb / 255.0
    rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
    rgb_norm = rgb_norm.to(device)
    assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0

    # ----------------- Predicting depth -----------------
    # Batch repeated input image
    duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
    single_rgb_dataset = TensorDataset(duplicated_rgb)
    if batch_size > 0:
        _bs = batch_size
    else:
        _bs = find_batch_size(
            ensemble_size=ensemble_size,
            input_res=max(rgb_norm.shape[1:]),
            dtype=self.dtype,
        )

    single_rgb_loader = DataLoader(
        single_rgb_dataset, batch_size=_bs, shuffle=False
    )

    # Predict depth maps (batched)
    depth_pred_ls = []
    if show_progress_bar:
        iterable = tqdm(
            single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
        )
    else:
        iterable = single_rgb_loader
    for batch in iterable:
        (batched_img,) = batch
        depth_pred_raw = self.single_infer(
            rgb_in=batched_img,
            num_inference_steps=denoising_steps,
            show_pbar=show_progress_bar,
        )
        depth_pred_ls.append(depth_pred_raw.detach().clone())
    depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
    torch.cuda.empty_cache()  # clear vram cache for ensembling

    # ----------------- Test-time ensembling -----------------
    if ensemble_size > 1:
        depth_pred, pred_uncert = ensemble_depths(
            depth_preds, **(ensemble_kwargs or {})
        )
    else:
        depth_pred = depth_preds
        pred_uncert = None

    # ----------------- Post processing -----------------
    # Scale prediction to [0, 1]
    min_d = torch.min(depth_pred)
    max_d = torch.max(depth_pred)
    depth_range = max_d - min_d
    if depth_range == 0:
        # A perfectly flat prediction would divide by zero and turn the
        # whole map into NaN; report it as all zeros instead.
        depth_pred = torch.zeros_like(depth_pred)
    else:
        depth_pred = (depth_pred - min_d) / depth_range

    # Convert to numpy
    depth_pred = depth_pred.to(torch.float32).cpu().numpy()

    # Resize back to original resolution
    if match_input_res:
        pred_img = Image.fromarray(depth_pred)
        pred_img = pred_img.resize(input_size)
        depth_pred = np.asarray(pred_img)

    # Clip output range
    depth_pred = depth_pred.clip(0, 1)

    # Colorize
    depth_colored = colorize_depth_maps(
        depth_pred, 0, 1, cmap=color_map
    ).squeeze()  # [3, H, W], value in (0, 1)
    depth_colored = (depth_colored * 255).astype(np.uint8)
    depth_colored_hwc = chw2hwc(depth_colored)
    depth_colored_img = Image.fromarray(depth_colored_hwc)
    return MarigoldDepthOutput(
        depth_np=depth_pred,
        depth_colored=depth_colored_img,
        uncertainty=pred_uncert,
    )
max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) @torch.no_grad() def single_infer( self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool ) -> torch.Tensor: """ Perform an individual depth prediction without ensembling. Args: rgb_in (`torch.Tensor`): Input RGB image. num_inference_steps (`int`): Number of diffusion denoisign steps (DDIM) during inference. show_pbar (`bool`): Display a progress bar of diffusion denoising. Returns: `torch.Tensor`: Predicted depth map. """ device = rgb_in.device # Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # [T] # Encode image rgb_latent = self.encode_rgb(rgb_in) # Initial depth map (noise) depth_latent = torch.randn( rgb_latent.shape, device=device, dtype=self.dtype ) # [B, 4, h, w] # Batched empty text embedding if self.empty_text_embed is None: self.__encode_empty_text() batch_empty_text_embed = self.empty_text_embed.repeat( (rgb_latent.shape[0], 1, 1) ) # [B, 2, 1024] # Denoising loop if show_pbar: iterable = tqdm( enumerate(timesteps), total=len(timesteps), leave=False, desc=" " * 4 + "Diffusion denoising", ) else: iterable = enumerate(timesteps) for _i, t in iterable: unet_input = torch.cat( [rgb_latent, depth_latent], dim=1 ) # this order is important # predict the noise residual noise_pred = self.unet( unet_input, t, encoder_hidden_states=batch_empty_text_embed ).sample # [B, 4, h, w] # compute the previous noisy sample x_t -> x_t-1 depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample torch.cuda.empty_cache() depth = self.decode_depth(depth_latent) # clip prediction depth = torch.clip(depth, -1.0, 1.0) # shift to [0, 1] depth = (depth + 1.0) / 2.0 return depth def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: """ Encode RGB image 
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
    """
    Map an RGB batch to its (scaled) VAE latent.

    Only the mean half of the encoder moments is used; the log-variance
    half is discarded, so the encoding is deterministic.

    Args:
        rgb_in (`torch.Tensor`):
            Input RGB image to be encoded.

    Returns:
        `torch.Tensor`: Image latent.
    """
    vae = self.vae
    moments = vae.quant_conv(vae.encoder(rgb_in))
    latent_mean, _ = moments.chunk(2, dim=1)
    return latent_mean * self.rgb_latent_scale_factor

def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
    """
    Turn a depth latent back into a single-channel depth map.

    Args:
        depth_latent (`torch.Tensor`):
            Depth latent to be decoded.

    Returns:
        `torch.Tensor`: Decoded depth map (channel mean of the decoder
        output, channel dimension kept).
    """
    unscaled = depth_latent / self.depth_latent_scale_factor
    vae = self.vae
    decoded = vae.decoder(vae.post_quant_conv(unscaled))
    # Collapse the decoder's output channels into one depth channel.
    return decoded.mean(dim=1, keepdim=True)
# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]


def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Automatically search for suitable operating batch size.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.
        dtype (`torch.dtype`):
            Inference dtype; only table rows with exactly this dtype are considered.

    Returns:
        `int`: Operating batch size. Falls back to 1 when CUDA is unavailable
        or no table row matches.
    """
    if not torch.cuda.is_available():
        return 1

    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
    filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
    # Smallest sufficient resolution first; within a resolution, largest VRAM tier first.
    for settings in sorted(
        filtered_bs_search_table,
        key=lambda k: (k["res"], -k["total_vram"]),
    ):
        if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
            bs = settings["bs"]
            half = math.ceil(ensemble_size / 2)
            if bs > ensemble_size:
                bs = ensemble_size
            elif half < bs < ensemble_size:
                # Two balanced batches instead of one large and one tiny batch.
                bs = half
            return bs

    return 1


def inter_distances(tensors: torch.Tensor):
    """
    To calculate the distance between each two depth maps.

    Args:
        tensors (`torch.Tensor`): Stack of maps, indexed along dim 0.

    Returns:
        `torch.Tensor`: `tensors[i] - tensors[j]` for every pair `i < j`,
        stacked along dim 0 in `torch.combinations` order. Empty along dim 0
        when fewer than two maps are given.
    """
    n_maps = tensors.shape[0]
    if n_maps < 2:
        return tensors.new_empty((0, *tensors.shape[1:]))
    pairs = torch.combinations(torch.arange(n_maps, device=tensors.device), r=2)
    return tensors[pairs[:, 0]] - tensors[pairs[:, 1]]


def ensemble_depths(
    input_images: torch.Tensor,
    regularizer_strength: float = 0.02,
    max_iter: int = 2,
    tol: float = 1e-3,
    reduction: str = "median",
    max_res: int = None,
):
    """
    To ensemble multiple affine-invariant depth images (up to scale and shift),
    by aligning estimating the scale and shift

    Args:
        input_images (`torch.Tensor`): Stack of depth maps; the code indexes it as [N, H, W].
        regularizer_strength (`float`): Weight of the penalty pulling the merged
            map's min/max towards 0/1.
        max_iter (`int`): Maximum number of BFGS iterations.
        tol (`float`): Tolerance passed to `scipy.optimize.minimize`.
        reduction (`str`): `"median"` or `"mean"`.
        max_res (`int`, *optional*): If given and smaller than the longer edge,
            the alignment is estimated on a nearest-neighbour downscaled copy;
            the returned maps are always at the input resolution.

    Returns:
        `(aligned_images, uncertainty)`: the merged map rescaled to [0, 1] and
        the per-pixel spread (std for `"mean"`, MAD for `"median"`) divided by
        the same range.

    Raises:
        ValueError: If `reduction` is not `"mean"` or `"median"`.
    """
    # Fail before any optimisation work instead of inside the scipy callback.
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unknown reduction method: {reduction}")

    device = input_images.device
    dtype = input_images.dtype
    np_dtype = np.float32

    original_input = input_images.clone()
    n_img = input_images.shape[0]
    ori_shape = input_images.shape

    if max_res is not None:
        scale_factor = min(max_res / ori_shape[-2], max_res / ori_shape[-1])
        if scale_factor < 1:
            # interpolate() needs [N, C, H, W]; nearest only picks existing values,
            # so the float32 round trip is lossless.
            input_images = (
                torch.nn.functional.interpolate(
                    input_images.unsqueeze(1).to(torch.float32),
                    scale_factor=scale_factor,
                    mode="nearest",
                )
                .squeeze(1)
                .to(dtype)
            )

    # init guess: map every member's own range onto [0, 1]
    np_img = input_images.reshape((n_img, -1)).to(torch.float32).cpu().numpy()
    _min = np.min(np_img, axis=1)
    _max = np.max(np_img, axis=1)
    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
    x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

    input_images = input_images.to(device)

    def _split(params):
        # First half of the vector are the scales, second half the shifts.
        half = len(params) // 2
        scale = torch.from_numpy(params[:half]).to(dtype=dtype).to(device)
        shift = torch.from_numpy(params[half:]).to(dtype=dtype).to(device)
        return scale.view((-1, 1, 1)), shift.view((-1, 1, 1))

    # objective function
    def closure(params):
        scale, shift = _split(params)
        transformed_arrays = input_images * scale + shift
        dists = inter_distances(transformed_arrays)
        sqrt_dist = torch.sqrt(torch.mean(dists**2))

        if "mean" == reduction:
            pred = torch.mean(transformed_arrays, dim=0)
        else:
            pred = torch.median(transformed_arrays, dim=0).values

        near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
        far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

        err = sqrt_dist + (near_err + far_err) * regularizer_strength
        err = err.to(torch.float32).detach().cpu().numpy().astype(np_dtype)
        return err

    res = minimize(
        closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
    )

    # Prediction: apply the fitted alignment to the full-resolution maps
    scale, shift = _split(res.x)
    transformed_arrays = original_input * scale + shift
    if "mean" == reduction:
        aligned_images = torch.mean(transformed_arrays, dim=0)
        uncertainty = torch.std(transformed_arrays, dim=0)
    else:
        aligned_images = torch.median(transformed_arrays, dim=0).values
        # MAD (median absolute deviation) as uncertainty indicator
        abs_dev = torch.abs(transformed_arrays - aligned_images)
        uncertainty = torch.median(abs_dev, dim=0).values

    # Scale and shift to [0, 1]
    _min = torch.min(aligned_images)
    _max = torch.max(aligned_images)
    aligned_images = (aligned_images - _min) / (_max - _min)
    uncertainty /= _max - _min

    return aligned_images, uncertainty
def colorize_depth_maps(
    depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
    """
    Colorize depth maps.

    Args:
        depth_map (`torch.Tensor` or `np.ndarray`): Depth map(s) with at least
            two dimensions; singleton dimensions are squeezed away.
        min_depth, max_depth: Depth values mapped to the two ends of the colormap.
        cmap (`str`): Name of a matplotlib colormap.
        valid_mask (`torch.Tensor` or `np.ndarray`, *optional*): Truthy where the
            depth is valid; invalid pixels are painted black.

    Returns:
        Colored map(s) laid out as [B, 3, H, W] with values in [0, 1]; a float
        CPU tensor for tensor input, otherwise a numpy array.

    Raises:
        TypeError: If `depth_map` is neither a tensor nor a numpy array.
    """
    is_tensor = isinstance(depth_map, torch.Tensor)
    if not is_tensor and not isinstance(depth_map, np.ndarray):
        raise TypeError(
            f"depth_map must be a torch.Tensor or np.ndarray, got {type(depth_map).__name__}"
        )
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if is_tensor:
        # .cpu() is required before .numpy() for tensors living on an accelerator.
        depth = depth_map.detach().cpu().clone().squeeze().numpy()
    else:
        depth = depth_map.copy().squeeze()
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = mpl.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)

    if valid_mask is not None:
        # Convert based on the mask's own type, not on the depth map's.
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        # Force bool so that `~` is a logical, not a bitwise, inversion.
        valid_mask = np.asarray(valid_mask).squeeze().astype(bool)  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        # Expand over the colour channels (and over the batch for a shared mask).
        valid_mask = np.broadcast_to(valid_mask, img_colored_np.shape)
        img_colored_np[~valid_mask] = 0

    if is_tensor:
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np


def chw2hwc(chw):
    """
    Move the leading channel axis of a 3-D array/tensor to the end.

    Raises:
        TypeError: If `chw` is neither a tensor nor a numpy array.
    """
    if not isinstance(chw, (torch.Tensor, np.ndarray)):
        raise TypeError(
            f"chw must be a torch.Tensor or np.ndarray, got {type(chw).__name__}"
        )
    assert 3 == len(chw.shape)
    if isinstance(chw, torch.Tensor):
        return torch.permute(chw, (1, 2, 0))
    return np.moveaxis(chw, 0, -1)
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
    """
    Resize image so that its longer edge equals `max_edge_resolution`,
    keeping aspect ratio.

    Note that the scale factor is not capped at 1, so an image smaller than
    the target is enlarged.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel).

    Returns:
        `Image.Image`: Resized image.
    """
    original_width, original_height = img.size
    scale = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )
    # int() truncates; keep at least one pixel so extreme aspect ratios
    # cannot produce a zero-sized edge, which PIL rejects.
    new_width = max(1, int(original_width * scale))
    new_height = max(1, int(original_height * scale))
    return img.resize((new_width, new_height))


def seed_all(seed: int = 0):
    """
    Set random seeds of all components.

    Seeds Python's `random`, numpy's global RNG, torch's default generator
    and every CUDA generator with the same value.
    """
    random.seed(seed)
    np.random.seed(seed)  # noqa
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Set to True once mediapipe has been imported successfully.
checked_ok = False


def check_dependencies():
    """
    Install mediapipe if it is missing and verify that it can be imported.

    Returns:
        bool: True when the import succeeded (also recorded in `checked_ok`),
        False when it failed; the failure is logged.
    """
    global checked_ok # pylint: disable=global-statement
    from installer import installed, install, log
    packages = [('mediapipe', 'mediapipe')]
    for pkg in packages:
        if not installed(pkg[1], reload=True, quiet=True):
            install(pkg[0], pkg[1], ignore=False)
    try:
        import mediapipe as mp # pylint: disable=unused-import
        checked_ok = True
        return True
    except Exception as e:
        log.error(f'MediaPipe: {e}')
        return False


class MediapipeFaceDetector:
    """Control processor that renders MediaPipe face-mesh annotations."""

    def __call__(self,
                 input_image: Union[np.ndarray, Image.Image] = None,
                 max_faces: int = 1,
                 min_confidence: float = 0.5,
                 output_type: str = "pil",
                 detect_resolution: int = 512,
                 image_resolution: int = 512,
                 **kwargs):
        """
        Annotate the faces found in `input_image`.

        Args:
            input_image: Image to process; non-array input is converted to a uint8 array.
            max_faces: Maximum number of faces passed to the annotator.
            min_confidence: Minimum detection confidence passed to the annotator.
            output_type: `"pil"` returns a PIL image, anything else the raw array.
            detect_resolution: Resolution the input is resized to before detection.
            image_resolution: Resolution used to size the returned annotation.

        Returns:
            The annotation image, or None when mediapipe is unavailable.

        Raises:
            ValueError: If `input_image` is None.
        """
        if not checked_ok:
            if not check_dependencies():
                return
        # Imported lazily so mediapipe is only loaded once it is known to be installed.
        from .mediapipe_face_util import generate_annotation
        if input_image is None:
            raise ValueError("input_image must be defined.")

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        detected_map = generate_annotation(input_image, max_faces, min_confidence)
        detected_map = HWC3(detected_map)

        # Only the size of the resized input is needed to scale the annotation.
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
modules.shared import log try: import mediapipe as mp except ImportError: log.error("Control processor MediaPipe: mediapipe not installed") mp = None if mp: mp_drawing = mp.solutions.drawing_utils mp_drawing_styles = mp.solutions.drawing_styles mp_face_detection = mp.solutions.face_detection # Only for counting faces. mp_face_mesh = mp.solutions.face_mesh mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS DrawingSpec = mp.solutions.drawing_styles.DrawingSpec PoseLandmark = mp.solutions.drawing_styles.PoseLandmark min_face_size_pixels: int = 64 f_thick = 2 f_rad = 1 right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad) head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. 
def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
    """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all landmarks.
    Until our PR is merged into mediapipe, we need this separate method.

    Paints a filled square of side `2 * halfwidth` around every selected
    landmark, in place.

    Args:
        image: H,W,3 array that is drawn on.
        landmark_list: Object whose `.landmark` sequence holds normalized landmarks.
        drawing_spec: Either a mapping from landmark index to a spec with a
            `.color`, or a single `DrawingSpec` applied to every landmark.
        halfwidth: Half the side length of the painted square, in pixels.

    Raises:
        ValueError: If `image` is not H,W,3.
        TypeError: If `drawing_spec` is neither a mapping nor a `DrawingSpec`.
    """
    if len(image.shape) != 3:
        raise ValueError("Input image must be H,W,C.")
    image_rows, image_cols, image_channels = image.shape
    if image_channels != 3:  # BGR channels
        raise ValueError('Input image must contain three channel bgr data.')
    for idx, landmark in enumerate(landmark_list.landmark):
        if (
                (landmark.HasField('visibility') and landmark.visibility < 0.9) or
                (landmark.HasField('presence') and landmark.presence < 0.5)
        ):
            continue
        if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
            continue
        image_x = int(image_cols*landmark.x)
        image_y = int(image_rows*landmark.y)
        if isinstance(drawing_spec, Mapping):
            spec = drawing_spec.get(idx)
            if spec is None:
                continue
            draw_color = spec.color
        elif isinstance(drawing_spec, DrawingSpec):
            draw_color = drawing_spec.color
        else:
            raise TypeError("drawing_spec must be a Mapping or a DrawingSpec.")
        # Clamp the start of the slice: a negative start would wrap around to the
        # far edge and silently skip pupils close to the top/left border.
        top = max(image_y - halfwidth, 0)
        left = max(image_x - halfwidth, 0)
        image[top:image_y+halfwidth, left:image_x+halfwidth, :] = draw_color


def reverse_channels(image):
    """Given a numpy array in RGB form, convert to BGR.  Will also convert from BGR to RGB."""
    # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order.
    # It returns a view; im[:,:,[2,1,0]] would also work but makes a copy of the data.
    return image[:, :, ::-1]


def generate_annotation(
        img_rgb,
        max_faces: int,
        min_confidence: float
):
    """
    Find up to 'max_faces' inside the provided input image.
    If the module-level min_face_size_pixels is nonzero it will be used to filter
    faces that occupy less than this many pixels in the image.

    Returns:
        An image of the same shape as `img_rgb` with the face annotations drawn
        on black; all black when no face is found. When mediapipe could not be
        imported, the input image is returned unchanged.
    """
    if mp is None:
        return img_rgb
    with mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=max_faces,
            refine_landmarks=True,
            min_detection_confidence=min_confidence,
    ) as facemesh:
        img_height, img_width, img_channels = img_rgb.shape
        assert img_channels == 3

        results = facemesh.process(img_rgb).multi_face_landmarks

        if results is None:
            log.warning("No faces detected in controlnet image for Mediapipe face annotator.")
            return np.zeros_like(img_rgb)

        # Filter faces that are too small
        filtered_landmarks = []
        for lm in results:
            if min_face_size_pixels > 0:
                # Bounding box of the face in normalized coordinates.
                xs = [point.x for point in lm.landmark]
                ys = [point.y for point in lm.landmark]
                face_width_pixels = (max(xs) - min(xs)) * img_width
                face_height_pixels = (max(ys) - min(ys)) * img_height
                if min(face_width_pixels, face_height_pixels) < min_face_size_pixels:
                    continue
            filtered_landmarks.append(lm)

        # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
        empty = np.zeros_like(img_rgb)

        # Draw detected faces:
        for face_landmarks in filtered_landmarks:
            mp_drawing.draw_landmarks(
                empty,
                face_landmarks,
                connections=face_connection_spec.keys(),
                landmark_drawing_spec=None,
                connection_drawing_spec=face_connection_spec
            )
            draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)

        # Flip BGR back to RGB.
        empty = reverse_channels(empty).copy()

        return empty
class MidasDetector:
    """Control processor producing MiDaS depth (and optionally normal) maps."""

    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, model_type="dpt_hybrid", filename=None, cache_dir=None, local_files_only=False):
        """
        Build a detector from a local directory or a Hugging Face repo id.

        When `filename` is not given, a default checkpoint name is used; the
        `lllyasviel/ControlNet` repo keeps it under `annotator/ckpts/`.
        """
        if pretrained_model_or_path == "lllyasviel/ControlNet":
            filename = filename or "annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
        else:
            filename = filename or "dpt_hybrid-midas-501f0c75.pt"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        model = MiDaSInference(model_type=model_type, model_path=model_path)
        return cls(model)

    def to(self, device):
        """Move the wrapped model to `device` and return self."""
        self.model.to(device)
        return self

    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, image_resolution=512, output_type=None):
        """
        Estimate depth for `input_image`.

        Args:
            input_image: Numpy array or anything `np.array` can convert (e.g. a PIL image).
            a: Constant z component used when building the normal map.
            bg_th: Normalized-depth threshold below which the x/y gradients are zeroed.
            depth_and_normal: Also compute and return a normal map.
            detect_resolution: Resolution the input is resized to before inference.
            image_resolution: Resolution used to size the returned maps.
            output_type: `"pil"` or `"np"`; defaults to the kind of the input.

        Returns:
            The depth image, or `(depth_image, normal_image)` when
            `depth_and_normal` is set.
        """
        self.model.to(devices.device)
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        image_depth = input_image
        image_depth = torch.from_numpy(image_depth).float()
        image_depth = image_depth.to(device)
        # Map uint8 [0, 255] to [-1, 1].
        image_depth = image_depth / 127.5 - 1.0
        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
        depth = self.model(image_depth)[0]

        # Normalize a copy to [0, 1]; the raw prediction is kept for the normals.
        depth_pt = depth.clone()
        depth_pt -= torch.min(depth_pt)
        # A completely flat prediction has max == 0 here; clamp the divisor so
        # the result is zeros instead of NaN going into the uint8 cast.
        depth_pt /= torch.max(depth_pt).clamp_min(torch.finfo(depth_pt.dtype).eps)
        depth_pt = depth_pt.cpu().numpy()
        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)

        if depth_and_normal:
            depth_np = depth.cpu().numpy()
            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
            z = np.ones_like(x) * a
            x[depth_pt < bg_th] = 0
            y[depth_pt < bg_th] = 0
            normal = np.stack([x, y, z], axis=2)
            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
            # Map [-1, 1] to [0, 255] and reverse the channel order.
            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]

        depth_image = HWC3(depth_image)
        if depth_and_normal:
            normal_image = HWC3(normal_image)

        # Only the size of the resized input is needed to scale the outputs.
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape

        depth_image = cv2.resize(depth_image, (W, H), interpolation=cv2.INTER_LINEAR)
        if depth_and_normal:
            normal_image = cv2.resize(normal_image, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            depth_image = Image.fromarray(depth_image)
            if depth_and_normal:
                normal_image = Image.fromarray(normal_image)

        if opts.control_move_processor:
            self.model.to('cpu')
        if depth_and_normal:
            return depth_image, normal_image
        else:
            return depth_image


ISL_PATHS = {
    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
    "midas_v21": "",
    "midas_v21_small": "",
}

remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    """
    Build the preprocessing transform for `model_type` (no weights are loaded).

    Raises:
        AssertionError: If `model_type` is not one of the supported names.
    """
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        raise AssertionError(f"model_type '{model_type}' not implemented, use: --model_type large")

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform


def load_model(model_type, model_path=None):
    """
    Load the network for `model_type` together with its preprocessing transform.

    Args:
        model_type: One of the keys of `ISL_PATHS`.
        model_path: Checkpoint path; defaults to `ISL_PATHS[model_type]`.

    Returns:
        `(model, transform)` with the model switched to eval mode.

    Raises:
        KeyError: If `model_path` is not given and `model_type` has no default path.
        AssertionError: If `model_type` is not implemented.
    """
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = model_path or ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        if not os.path.exists(model_path):
            from basicsr.utils.download_util import load_file_from_url
            # NOTE(review): the download always targets annotator_ckpts_path, which
            # only matches model_path when the default path is in use - confirm.
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)

        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})

    else:
        raise AssertionError(f"model_type '{model_type}' not implemented, use: --model_type large")

    # Single source of truth for the per-model preprocessing settings.
    transform = load_midas_transform(model_type)

    return model.eval(), transform


class MiDaSInference(nn.Module):
    """Thin `nn.Module` wrapper around a loaded MiDaS network."""

    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small"
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type, model_path):
        super().__init__()
        assert (model_type in self.MODEL_TYPES_ISL)
        model, _ = load_model(model_type, model_path)
        self.model = model
        # Assigned as a plain function on the instance, so it is not bound: a call
        # such as `model.train(mode)` passes `mode` in the `self` position.
        self.model.train = disabled_train

    def forward(self, x):
        prediction = self.model(x)
        return prediction
def _make_encoder(
    backbone,
    features,
    use_pretrained,
    groups=1,
    expand=False,
    exportable=True,
    hooks=None,
    use_vit_only=False,
    use_readout="ignore",
):
    """Create the encoder backbone and its matching ``scratch`` projection layers.

    Args:
        backbone (str): "vitl16_384", "vitb_rn50_384", "vitb16_384",
            "resnext101_wsl" or "efficientnet_lite3".
        features (int): output channel count passed to ``_make_scratch``.
        use_pretrained (bool): forwarded to the backbone factory.
        groups (int): convolution groups for the scratch layers.
        expand (bool): forwarded to ``_make_scratch``.
        exportable (bool): only used by the EfficientNet backbone.
        hooks (list[int] | None): block indices to tap (ViT backbones only).
        use_vit_only (bool): only used by the "vitb_rn50_384" backbone.
        use_readout (str): readout handling (ViT backbones only).

    Returns:
        tuple: ``(pretrained, scratch)`` modules.

    Raises:
        AssertionError: if ``backbone`` is not one of the names above.
    """
    # The four-element lists below are the per-stage input channel counts
    # handed to _make_scratch for each backbone.
    if backbone == "vitl16_384":
        # ViT-L/16
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )
    elif backbone == "vitb_rn50_384":
        # ViT-B/16 + ResNet-50 hybrid
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )
    elif backbone == "vitb16_384":
        # ViT-B/16
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )
    elif backbone == "resnext101_wsl":
        # ResNeXt-101 WSL
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)
    elif backbone == "efficientnet_lite3":
        # EfficientNet-Lite3
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)
    else:
        # Carry the reason in the exception instead of printing and raising
        # an empty AssertionError.
        raise AssertionError(f"Backbone '{backbone}' not implemented")

    return pretrained, scratch
def _make_efficientnet_backbone(effnet):
    """Regroup an EfficientNet into four sequential stages.

    Args:
        effnet: model exposing ``conv_stem``, ``bn1``, ``act1`` and a
            sliceable ``blocks`` sequence with at least nine entries' worth
            of indices (slices beyond the end are simply shorter).

    Returns:
        nn.Module: container with ``layer1`` .. ``layer4`` attributes.
    """
    stages = effnet.blocks
    backbone = nn.Module()

    # layer1 = stem (conv, norm, activation) followed by the first two stages.
    backbone.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *stages[0:2]
    )
    # Remaining layers are plain slices of the block list.
    for index, (first, last) in enumerate(((2, 3), (3, 5), (5, 9)), start=2):
        setattr(backbone, f"layer{index}", nn.Sequential(*stages[first:last]))

    return backbone
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally adds a refined skip feature to the incoming path, refines the
    sum, and upsamples the result by a factor of two.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Args:
            *xs (tensor): one or two inputs; with two, ``xs[1]`` is passed
                through ``resConfUnit1`` and added to ``xs[0]``.

        Returns:
            tensor: output, upsampled 2x with bilinear interpolation
        """
        output = xs[0]

        if len(xs) == 2:
            # Out-of-place add: `output += ...` would write the sum into the
            # caller's xs[0] tensor (and can break autograd on tensors that
            # are still needed for backward).
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output
Args: features (int): number of features """ super().__init__() self.bn = bn self.groups=1 self.conv1 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups ) self.conv2 = nn.Conv2d( features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups ) if self.bn is True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) self.activation = activation self.skip_add = nn.quantized.FloatFunctional() def forward(self, x): """Forward pass. Args: x (tensor): input Returns: tensor: output """ out = self.activation(x) out = self.conv1(out) if self.bn is True: out = self.bn1(out) out = self.activation(out) out = self.conv2(out) if self.bn is True: out = self.bn2(out) if self.groups > 1: out = self.conv_merge(out) return self.skip_add.add(out, x) # return out + x class FeatureFusionBlock_custom(nn.Module): """Feature fusion block. """ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): """Init. Args: features (int): number of features """ super(FeatureFusionBlock_custom, self).__init__() self.deconv = deconv self.align_corners = align_corners self.groups=1 self.expand = expand out_features = features if self.expand is True: out_features = features//2 self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): """Forward pass. 
class DPT(BaseModel):
    """Dense prediction network: ViT-based encoder, four fusion blocks and a head."""

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):
        """Init.

        Args:
            head (nn.Module): module applied to the last fused feature map.
            features (int): channel count of the fusion path.
            backbone (str): "vitb_rn50_384", "vitb16_384" or "vitl16_384".
            readout (str): readout handling forwarded to the encoder.
            channels_last (bool): convert the input to channels_last memory
                format before running the encoder.
            use_bn (bool): enable batch norm inside the fusion blocks.

        Raises:
            KeyError: if ``backbone`` has no entry in the hook table below.
        """
        super().__init__()

        self.channels_last = channels_last

        # Transformer block indices tapped for each supported backbone.
        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # use_pretrained: backbone is built without pretrained weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): 4-D input batch (unpacked as b, c, h, w by ``forward_vit``)

        Returns:
            tensor: output of the head
        """
        if self.channels_last is True:
            # contiguous() returns a new tensor; the result must be rebound,
            # otherwise the requested memory format is silently ignored.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the deepest feature map back up to the shallowest.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out
Defaults to resnet50 """ print("Loading weights: ", path) super(MidasNet, self).__init__() use_pretrained = False if path is None else True self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) self.scratch.refinenet2 = FeatureFusionBlock(features) self.scratch.refinenet1 = FeatureFusionBlock(features) self.scratch.output_conv = nn.Sequential( nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), Interpolate(scale_factor=2, mode="bilinear"), nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(True), nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), ) if path: self.load(path) def forward(self, x): """Forward pass. Args: x (tensor): input data (image) Returns: tensor: depth """ layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) ================================================ FILE: modules/control/proc/midas/midas/midas_net_custom.py ================================================ """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
class MidasNet_small(BaseModel):
    """Network for monocular depth estimation."""

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False,
                 align_corners=True, blocks=None):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder.
                Defaults to "efficientnet_lite3".
            non_negative (bool, optional): End the head with a ReLU. Defaults to True.
            exportable (bool, optional): Forwarded to the encoder factory.
            channels_last (bool, optional): Convert the input to channels_last
                memory format in forward. Defaults to False.
            align_corners (bool, optional): Forwarded to the fusion blocks.
            blocks (dict, optional): Options; ``{"expand": True}`` when omitted.
        """
        if blocks is None:
            blocks = {"expand": True}
        print("Loading weights: ", path)

        super().__init__()

        # Backbone weights are only requested when no checkpoint path is given.
        use_pretrained = not path

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] is True:
            # Expanded mode doubles the channel count at each deeper stage.
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        # The last fusion block is created without `expand`, so it keeps its width.
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last is True:
            # contiguous() returns a new tensor; rebind it so the memory
            # format actually takes effect. (The per-call debug print that
            # used to sit here was dropped.)
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the deepest feature map back up to the shallowest.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method (int, optional): OpenCV interpolation
                flag used for the image. Defaults to cv2.INTER_AREA.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round ``x`` to a multiple of ``ensure_multiple_of`` within optional bounds.

        Rounds to nearest first; rounds down instead if that exceeds
        ``max_val``, and rounds up instead if the result is below ``min_val``.
        """
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output ``(width, height)`` for an input of the given size.

        Raises:
            ValueError: if the configured ``resize_method`` is unknown.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize ``sample["image"]`` and, if enabled, its targets in place.

        Args:
            sample (dict): must contain "image"; "disparity", "depth" and
                "mask" are resized when present and ``resize_target`` is truthy.

        Returns:
            dict: the same ``sample`` object.
        """
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            # Guarded like the other targets: an image-only sample must not
            # fail with KeyError just because resize_target is enabled.
            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                sample["mask"] = sample["mask"].astype(bool)

        return sample
class AddReadout(nn.Module):
    """Drop the leading readout token(s) and add the readout to every remaining token."""

    def __init__(self, start_index=1):
        """Init.

        Args:
            start_index (int): index of the first non-readout token.
        """
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): token sequence indexed as (batch, token, ...)

        Returns:
            tensor: tokens from ``start_index`` on, each offset by the readout
        """
        first_token = x[:, 0]
        # With two leading tokens the readout is their mean, otherwise token 0.
        readout = (first_token + x[:, 1]) / 2 if self.start_index == 2 else first_token
        tokens = x[:, self.start_index :]
        return tokens + readout.unsqueeze(1)
def _resize_pos_embed(self, posemb, gs_h, gs_w):
    """Bilinearly resample the grid part of a position embedding.

    Args:
        posemb (tensor): embedding indexed as (1, tokens, dim); the first
            ``self.start_index`` tokens are kept unchanged.
        gs_h (int): target grid height.
        gs_w (int): target grid width.

    Returns:
        tensor: leading tokens followed by ``gs_h * gs_w`` resampled grid tokens.
    """
    start = self.start_index
    prefix_tokens = posemb[:, :start]
    grid_tokens = posemb[0, start:]

    # The stored grid is treated as square.
    old_side = int(math.sqrt(len(grid_tokens)))

    # (tokens, dim) -> (1, dim, side, side) so interpolate sees an image.
    grid = grid_tokens.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_h, gs_w), mode="bilinear")
    # Back to a flat token sequence.
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    return torch.cat([prefix_tokens, grid], dim=1)
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Create one readout-handling module per entry of ``features``.

    Args:
        vit_features (int): token width, used only by the "project" mode.
        features (sequence): one entry per output stage; only its length matters.
        use_readout (str): "ignore", "add" or "project".
        start_index (int): index of the first non-readout token.

    Returns:
        list: readout modules. For "ignore" and "add" every entry is the same
        module instance; "project" builds a separate module per entry.

    Raises:
        AssertionError: if ``use_readout`` is not one of the three modes.
    """
    if use_readout == "ignore":
        return [Slice(start_index)] * len(features)

    if use_readout == "add":
        return [AddReadout(start_index)] * len(features)

    if use_readout == "project":
        # Each stage gets its own projection since it holds trainable weights.
        return [ProjectReadout(vit_features, start_index) for _ in features]

    raise AssertionError("wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'")
stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess2 = nn.Sequential( readout_oper[1], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[1], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0, ), ) pretrained.act_postprocess4 = nn.Sequential( readout_oper[3], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[3], kernel_size=1, stride=1, padding=0, ), nn.Conv2d( in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1, ), ) pretrained.model.start_index = start_index pretrained.model.patch_size = [16, 16] # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """Create the "vit_large_patch16_384" timm model wrapped as a hooked backbone.

    Args:
        pretrained (bool): forwarded to ``timm.create_model``.
        use_readout (str): readout handling forwarded to the backbone builder.
        hooks (list[int] | None): block indices to tap; [5, 11, 17, 23] when None.

    Returns:
        nn.Module: result of ``_make_vit_b16_backbone``.
    """
    vit_large = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    if hooks is None:
        hooks = [5, 11, 17, 23]

    return _make_vit_b16_backbone(
        vit_large,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( get_activation("1") ) pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( get_activation("2") ) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) pretrained.activations = activations readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) if use_vit_only is True: pretrained.act_postprocess1 = nn.Sequential( readout_oper[0], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[0], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[0], out_channels=features[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess2 = nn.Sequential( readout_oper[1], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[1], kernel_size=1, stride=1, padding=0, ), nn.ConvTranspose2d( in_channels=features[1], out_channels=features[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1, ), ) else: pretrained.act_postprocess1 = nn.Sequential( nn.Identity(), nn.Identity(), nn.Identity() ) pretrained.act_postprocess2 = nn.Sequential( nn.Identity(), nn.Identity(), nn.Identity() ) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0, ), ) pretrained.act_postprocess4 = nn.Sequential( readout_oper[3], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[3], kernel_size=1, 
stride=1, padding=0, ), nn.Conv2d( in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1, ), ) pretrained.model.start_index = start_index pretrained.model.patch_size = [16, 16] # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) return pretrained def _make_pretrained_vitb_rn50_384( pretrained, use_readout="ignore", hooks=None, use_vit_only=False ): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks is None else hooks return _make_vit_b_rn50_backbone( model, features=[256, 512, 768, 768], size=[384, 384], hooks=hooks, use_vit_only=use_vit_only, use_readout=use_readout, ) ================================================ FILE: modules/control/proc/midas/utils.py ================================================ """Utils for monoDepth.""" import sys import re import numpy as np import cv2 import torch def read_pfm(path): """Read pfm file. 
def write_pfm(path, image, scale=1):
    """Write an image to a PFM (Portable Float Map) file.

    Args:
        path (str): path to file
        image (array): float32 data shaped H x W x 3 (color), H x W x 1 or
            H x W (greyscale)
        scale (int, optional): Scale. Defaults to 1. It is written negated
            when the data is little-endian, which is how PFM encodes byte order.

    Raises:
        TypeError: if ``image`` is not float32.
        ValueError: if ``image`` has an unsupported shape.
    """
    # Validate before opening the file so a rejected image does not leave an
    # empty or truncated file behind.
    if image.dtype.name != "float32":
        raise TypeError("Image dtype must be float32.")

    if image.ndim == 3 and image.shape[2] == 3:  # color image
        color = True
    elif image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):  # greyscale
        color = False
    else:
        raise ValueError("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

    # Rows are written bottom-to-top (read_pfm flips them back).
    image = np.flipud(image)

    endian = image.dtype.byteorder
    if endian == "<" or (endian == "=" and sys.byteorder == "little"):
        scale = -scale

    with open(path, "wb") as file:
        # The whole header is bytes; the previous expression only encoded the
        # greyscale tag and tried to write a str for color images.
        file.write(b"PF\n" if color else b"Pf\n")
        file.write(b"%d %d\n" % (image.shape[1], image.shape[0]))
        file.write(b"%f\n" % scale)
        image.tofile(file)
def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    The raw values go to ``path + ".pfm"`` as float32; a min-max normalised
    copy goes to ``path + ".png"``.

    Args:
        path (str): filepath without extension
        depth (array): depth
        bits (int, optional): bytes per PNG sample - 1 writes uint8, 2 writes
            uint16. Any other value writes only the pfm file. Defaults to 1.
    """
    write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        # Constant depth map: avoid dividing by zero and emit an all-zero
        # image. ndarray exposes `dtype`, not `type`.
        out = np.zeros(depth.shape, dtype=depth.dtype)

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2021-present NAVER Corp. 
class MLSDdetector:
    """M-LSD line-segment detector used as a control pre-processor.

    Wraps a ``MobileV2_MLSD_Large`` network and renders the detected line
    segments as white strokes on a black image.
    """

    def __init__(self, model):
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        """Build a detector from a local directory or a Hugging Face repo id.

        Args:
            pretrained_model_or_path (str): local directory containing the
                checkpoint, or a hub repository id.
            filename (str, optional): checkpoint file name; a repo-specific
                default is used when omitted.
            cache_dir (str, optional): forwarded to ``hf_hub_download``.
            local_files_only (bool): forwarded to ``hf_hub_download``.

        Returns:
            MLSDdetector: detector holding the model in eval mode.
        """
        if pretrained_model_or_path == "lllyasviel/ControlNet":
            filename = filename or "annotator/ckpts/mlsd_large_512_fp32.pth"
        else:
            filename = filename or "mlsd_large_512_fp32.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        model = MobileV2_MLSD_Large()
        # Load onto CPU regardless of the device the checkpoint was saved
        # from; the model is moved to the target device in __call__.
        model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
        model.eval()

        return cls(model)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Detect line segments and draw them on a black canvas.

        Args:
            input_image: numpy array, or anything ``np.array`` can convert to
                a uint8 array.
            thr_v (float): score threshold passed to ``pred_lines``.
            thr_d (float): distance threshold passed to ``pred_lines``.
            detect_resolution (int): resolution passed to ``resize_image``
                before detection.
            image_resolution (int): resolution passed to ``resize_image`` to
                size the returned map.
            output_type (str): ``"pil"`` returns a PIL image; anything else
                returns the numpy array.
            **kwargs: accepted and ignored.

        Returns:
            The line map; it stays blank when detection fails.
        """
        self.model.to(devices.device)
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        img = input_image
        img_output = np.zeros_like(img)

        try:
            lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
            for line in lines:
                x_start, y_start, x_end, y_end = [int(val) for val in line]
                cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
        except Exception as e:
            # Detection stays best-effort (a blank map is returned), but the
            # failure is reported instead of being silently discarded.
            import logging
            logging.getLogger(__name__).warning("MLSD line detection failed: %s", e)

        detected_map = img_output[:, :, 0]

        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        if opts.control_move_processor:
            self.model.to('cpu')
        return detected_map
class BlockTypeC(nn.Module):
    """Output head: dilated 3x3 conv, plain 3x3 conv, then a 1x1 projection.

    Both 3x3 stages keep ``in_c`` channels and are followed by BatchNorm and
    ReLU; only the final 1x1 conv changes the channel count to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # dilation=5 with padding=5 on a 3x3 kernel keeps the spatial size.
        dilated = nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5)
        self.conv1 = nn.Sequential(dilated, nn.BatchNorm2d(in_c), nn.ReLU())

        regular = nn.Conv2d(in_c, in_c, kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(regular, nn.BatchNorm2d(in_c), nn.ReLU())

        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        """Apply the three stages in order and return the projected map."""
        return self.conv3(self.conv2(self.conv1(x)))
class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        Truncated MobileNet V2 feature extractor used as the M-LSD backbone.

        Compared with the stock architecture, the stem takes 4 input channels
        and the last two inverted-residual stages (160 and 320 channels) are
        left out. The forward pass returns five intermediate feature maps
        instead of class logits.

        Args:
            pretrained (bool): when true, download the torchvision
                MobileNet V2 checkpoint and copy every tensor whose name and
                shape match this network.
        """
        super(MobileNetV2, self).__init__()

        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            # stock stages [6, 160, 3, 2] and [6, 320, 1, 1] are not built
        ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(4, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # only the first block of a stage applies the stage stride
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        self.features = nn.Sequential(*features)
        # Indices into self.features whose outputs are returned: the last
        # block of each of the five stages (16, 24, 32, 64 and 96 channels).
        self.fpn_selected = [1, 3, 6, 10, 13]
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
        if pretrained:
            self._load_pretrained_model()

    def _forward_impl(self, x):
        """Run the feature stack and return the five selected feature maps."""
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            if i > self.fpn_selected[-1]:
                break
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)

        # Unpacking doubles as a check that exactly five maps were collected.
        c1, c2, c3, c4, c5 = fpn_features
        return c1, c2, c3, c4, c5

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        """Download the torchvision MobileNet V2 weights and load what fits."""
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        self._load_matching_weights(pretrain_dict)

    def _load_matching_weights(self, pretrain_dict):
        """Copy checkpoint tensors whose name and shape both match this model.

        Matching on the name alone is not enough: the stem conv here takes 4
        input channels, so a same-named 3-channel tensor would make
        ``load_state_dict`` fail with a size mismatch.
        """
        state_dict = self.state_dict()
        model_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in state_dict and v.shape == state_dict[k].shape
        }
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
forward(self, a, b): b = self.conv1(b) a = self.conv2(a) b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) return torch.cat((a, b), dim=1) class BlockTypeB(nn.Module): def __init__(self, in_c, out_c): super(BlockTypeB, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), nn.BatchNorm2d(in_c), nn.ReLU() ) self.conv2 = nn.Sequential( nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU() ) def forward(self, x): x = self.conv1(x) + x x = self.conv2(x) return x class BlockTypeC(nn.Module): def __init__(self, in_c, out_c): super(BlockTypeC, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), nn.BatchNorm2d(in_c), nn.ReLU() ) self.conv2 = nn.Sequential( nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), nn.BatchNorm2d(in_c), nn.ReLU() ) self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) return x def _make_divisible(v, divisor, min_value=None): """ This function is taken from the original tf repo. It ensures that all layers have a channel number that is divisible by 8 It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py :param v: :param divisor: :param min_value: :return: """ if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. 
if new_v < 0.9 * v: new_v += divisor return new_v class ConvBNReLU(nn.Sequential): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): self.channel_pad = out_planes - in_planes self.stride = stride #padding = (kernel_size - 1) // 2 # TFLite uses slightly different padding than PyTorch if stride == 2: padding = 0 else: padding = (kernel_size - 1) // 2 super(ConvBNReLU, self).__init__( nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), nn.BatchNorm2d(out_planes), nn.ReLU6(inplace=True) ) self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) def forward(self, x): # TFLite uses different padding if self.stride == 2: x = F.pad(x, (0, 1, 0, 1), "constant", 0) #print(x.shape) for module in self: if not isinstance(module, nn.MaxPool2d): x = module(x) return x class InvertedResidual(nn.Module): def __init__(self, inp, oup, stride, expand_ratio): super(InvertedResidual, self).__init__() self.stride = stride assert stride in [1, 2] hidden_dim = int(round(inp * expand_ratio)) self.use_res_connect = self.stride == 1 and inp == oup layers = [] if expand_ratio != 1: # pw layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) layers.extend([ # dw ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), ]) self.conv = nn.Sequential(*layers) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileNetV2(nn.Module): def __init__(self, pretrained=True): """ MobileNet V2 main class Args: num_classes (int): Number of classes width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount inverted_residual_setting: Network structure round_nearest (int): Round the number of channels in each layer to be a multiple of this number Set to 1 to turn off rounding block: Module specifying inverted residual building block for mobilenet """ 
super(MobileNetV2, self).__init__() block = InvertedResidual input_channel = 32 last_channel = 1280 width_mult = 1.0 round_nearest = 8 inverted_residual_setting = [ # t, c, n, s [1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2], #[6, 96, 3, 1], #[6, 160, 3, 2], #[6, 320, 1, 1], ] # only check the first element, assuming user knows t,c,n,s are required if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: raise ValueError("inverted_residual_setting should be non-empty " "or a 4-element list, got {}".format(inverted_residual_setting)) # building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) features = [ConvBNReLU(4, input_channel, stride=2)] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: output_channel = _make_divisible(c * width_mult, round_nearest) for i in range(n): stride = s if i == 0 else 1 features.append(block(input_channel, output_channel, stride, expand_ratio=t)) input_channel = output_channel self.features = nn.Sequential(*features) self.fpn_selected = [3, 6, 10] # weight initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out') if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias) #if pretrained: # self._load_pretrained_model() def _forward_impl(self, x): # This exists since TorchScript doesn't support inheritance, so the superclass method # (this one) needs to have a name other than `forward` that can be accessed in a subclass fpn_features = [] for i, f in enumerate(self.features): if i > self.fpn_selected[-1]: break x = f(x) if i in self.fpn_selected: fpn_features.append(x) c2, c3, c4 = fpn_features return c2, c3, c4 def forward(self, 
class MobileV2_MLSD_Tiny(nn.Module):
    """MobileNetV2 backbone followed by the fusion blocks 12-16 of the tiny M-LSD head."""

    def __init__(self):
        super().__init__()
        self.backbone = MobileNetV2(pretrained=True)

        # c3 + c4 fusion
        self.block12 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
        self.block13 = BlockTypeB(128, 64)

        # c2 + previous stage fusion
        self.block14 = BlockTypeA(in_c1=24, in_c2=64, out_c1=32, out_c2=32)
        self.block15 = BlockTypeB(64, 64)

        self.block16 = BlockTypeC(64, 16)

    def forward(self, x):
        """Return the head output with the first 7 channels dropped, upsampled 2x bilinearly."""
        c2, c3, c4 = self.backbone(x)

        fused = self.block13(self.block12(c3, c4))
        fused = self.block15(self.block14(c2, fused))
        head = self.block16(fused)

        # only channels 7 and up are kept
        kept = head[:, 7:, :, :]
        return F.interpolate(kept, scale_factor=2.0, mode='bilinear', align_corners=True)
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
    '''
    Pick the strongest local maxima of the center heat-map and return them with the displacement map.

    tpMap:
    center: tpMap[1, 0, :, :]
    displacement: tpMap[1, 1:5, :, :]

    Args:
        tpMap: tensor of shape (1, C, H, W); channel 0 holds center logits and
            channels 1..4 hold displacements. Only batch size 1 is supported.
        topk_n: maximum number of peaks to return; capped at H * W so small maps
            no longer make torch.topk raise.
        ksize: window size of the max-pool used for non-maximum suppression.

    Returns:
        ptss: integer numpy array (k, 2) of (y, x) peak coordinates, best first
        scores: numpy array (k,) with the sigmoid score of each peak
        displacement: numpy array (H, W, 4) with the displacement channels last
    '''
    b, _, _, w = tpMap.shape
    assert b==1, 'only support bsize==1'
    displacement = tpMap[:, 1:5, :, :][0]
    center = tpMap[:, 0, :, :]
    heat = torch.sigmoid(center)
    # non-maximum suppression: a cell survives only if it equals the max of its window
    hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2)
    keep = (hmax == heat).float()
    heat = (heat * keep).reshape(-1)

    # torch.topk raises when k exceeds the number of elements, so cap it
    k = min(topk_n, heat.numel())
    scores, indices = torch.topk(heat, k, dim=-1, largest=True)
    # explicit floor rounding, same idiom as the suppression step in pred_squares
    yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
    xx = torch.fmod(indices, w).unsqueeze(-1)
    ptss = torch.cat((yy, xx), dim=-1)

    ptss = ptss.detach().cpu().numpy()
    scores = scores.detach().cpu().numpy()
    displacement = displacement.detach().cpu().numpy().transpose((1, 2, 0))
    return ptss, scores, displacement
def pred_squares(image, model, input_shape=None, params=None):
    '''
    Detect line segments with the model, merge them, and assemble ranked quadrilaterals.

    shape = [height, width]

    Args:
        image: array whose shape unpacks as (h, w, channels); it is resized to
            input_shape and gets an extra all-ones channel appended before the model call.
            Values are scaled with x / 127.5 - 1, so a 0..255 range is presumably expected - confirm with caller.
        model: torch module; called with a (1, C+1, H, W) float tensor on the device of its parameters.
        input_shape: [height, width] fed to the model, defaults to [512, 512].
        params: dict with keys 'score', 'outside_ratio', 'inside_ratio', 'w_overlap',
            'w_degree', 'w_length', 'w_area', 'w_center'; defaults are used when None.

    Returns:
        (new_segments, squares, score_array, inter_points), rescaled to the original image
        size. Each item falls back to an empty list when it cannot be computed, and all four
        are empty lists when no segment passes the thresholds.
    '''
    if params is None:
        params = {'score': 0.06, 'outside_ratio': 0.28, 'inside_ratio': 0.45, 'w_overlap': 0.0, 'w_degree': 1.95, 'w_length': 0.0, 'w_area': 1.86, 'w_center': 0.14}
    if input_shape is None:
        input_shape = [512, 512]
    h, w, _ = image.shape
    original_shape = [h, w]
    device = next(iter(model.parameters())).device

    # cv2.resize takes dsize as (width, height); the ones-channel below is (height, width, 1),
    # so the two only line up for non-square input_shape when the order is swapped here.
    resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
    resized_image = resized_image.transpose((2, 0, 1))
    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
    batch_image = (batch_image / 127.5) - 1.0

    batch_image = torch.from_numpy(batch_image).float().to(device)
    outputs = model(batch_image)

    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
    start = vmap[:, :, :2]  # (x, y)
    end = vmap[:, :, 2:]  # (x, y)
    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))

    segments_list = []
    for junc, score in zip(pts, pts_score):
        y, x = junc
        distance = dist_map[y, x]
        if score > params['score'] and distance > 20.0:
            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
            d_arrow = 1.0
            x_start = x + d_arrow * disp_x_start
            y_start = y + d_arrow * disp_y_start
            x_end = x + d_arrow * disp_x_end
            y_end = y + d_arrow * disp_y_end
            segments_list.append([x_start, y_start, x_end, y_end])

    if not segments_list:
        # Nothing passed the thresholds; the 2-D slicing below needs at least one row.
        return [], [], [], []

    segments = np.array(segments_list)

    ####### post processing for squares
    # 1. get unique lines
    point = np.array([[0, 0]])
    point = point[0]
    start = segments[:, :2]
    end = segments[:, 2:]
    diff = start - end
    a = diff[:, 1]
    b = -diff[:, 0]
    c = a * start[:, 0] + b * start[:, 1]

    # (distance to origin, angle) representation of every segment
    d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
    theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
    theta[theta < 0.0] += 180
    hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)

    d_quant = 1
    theta_quant = 2
    hough[:, 0] //= d_quant
    hough[:, 1] //= theta_quant
    _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)

    # NOTE(review): the accumulator size is hard-coded to 512 distance bins regardless of input_shape.
    acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
    idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
    yx_indices = hough[indices, :].astype('int32')
    acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
    idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices

    acc_map_np = acc_map

    ### fast suppression using pytorch op
    acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
    _, _, _, acc_w = acc_map.shape
    max_acc_map = F.max_pool2d(acc_map, kernel_size=5, stride=1, padding=2)
    acc_map = acc_map * ((acc_map == max_acc_map).float())
    flatten_acc_map = acc_map.reshape([-1, ])

    scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
    yy = torch.div(indices, acc_w, rounding_mode='floor').unsqueeze(-1)
    xx = torch.fmod(indices, acc_w).unsqueeze(-1)
    yx = torch.cat((yy, xx), dim=-1)

    yx = yx.detach().cpu().numpy()

    topk_values = scores.detach().cpu().numpy()
    indices = idx_map[yx[:, 0], yx[:, 1]]
    basis = 5 // 2

    merged_segments = []
    for yx_pt, max_indice, value in zip(yx, indices, topk_values):
        y, x = yx_pt
        if max_indice == -1 or value == 0:
            continue
        segment_list = []
        # gather every segment whose bin lies in the 5x5 neighbourhood of this peak
        for y_offset in range(-basis, basis + 1):
            for x_offset in range(-basis, basis + 1):
                indice = idx_map[y + y_offset, x + x_offset]
                cnt = int(acc_map_np[y + y_offset, x + x_offset])
                if indice != -1:
                    segment_list.append(segments[indice])
                if cnt > 1:
                    # idx_map only remembers one segment per bin; look up the others
                    check_cnt = 1
                    current_hough = hough[indice]
                    for new_indice, new_hough in enumerate(hough):
                        if (current_hough == new_hough).all() and indice != new_indice:
                            segment_list.append(segments[new_indice])
                            check_cnt += 1
                        if check_cnt == cnt:
                            break
        group_segments = np.array(segment_list).reshape([-1, 2])
        sorted_group_segments = np.sort(group_segments, axis=0)
        x_min, y_min = sorted_group_segments[0, :]
        x_max, y_max = sorted_group_segments[-1, :]

        deg = theta[max_indice]
        if deg >= 90:
            merged_segments.append([x_min, y_max, x_max, y_min])
        else:
            merged_segments.append([x_min, y_min, x_max, y_max])

    if not merged_segments:
        # same reason as above: the intersection maths needs a 2-D array
        return [], [], [], []

    # 2. get intersections
    new_segments = np.array(merged_segments)  # (x1, y1, x2, y2)
    start = new_segments[:, :2]  # (x1, y1)
    end = new_segments[:, 2:]  # (x2, y2)
    new_centers = (start + end) / 2.0
    diff = start - end
    dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))

    # ax + by = c
    a = diff[:, 1]
    b = -diff[:, 0]
    c = a * start[:, 0] + b * start[:, 1]
    pre_det = a[:, None] * b[None, :]
    det = pre_det - np.transpose(pre_det)

    pre_inter_y = a[:, None] * c[None, :]
    inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
    pre_inter_x = c[:, None] * b[None, :]
    inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
    inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')

    # 3. get corner information
    # 3.1 get distance
    '''
    dist_segments:
        | dist(0), dist(1), dist(2), ...|
    dist_inter_to_segment1:
        | dist(inter,0), dist(inter,0), dist(inter,0), ... |
        | dist(inter,1), dist(inter,1), dist(inter,1), ... |
        ...
    dist_inter_to_semgnet2:
        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
        ...
    '''

    dist_inter_to_segment1_start = np.sqrt(
        np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment1_end = np.sqrt(
        np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment2_start = np.sqrt(
        np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
    dist_inter_to_segment2_end = np.sqrt(
        np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]

    # sort ascending
    dist_inter_to_segment1 = np.sort(
        np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
        axis=-1)  # [n_batch, n_batch, 2]
    dist_inter_to_segment2 = np.sort(
        np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
        axis=-1)  # [n_batch, n_batch, 2]

    # 3.2 get degree
    inter_to_start = new_centers[:, None, :] - inter_pts
    deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
    deg_inter_to_start[deg_inter_to_start < 0.0] += 360
    inter_to_end = new_centers[None, :, :] - inter_pts
    deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
    deg_inter_to_end[deg_inter_to_end < 0.0] += 360

    '''
    B -- G
    |    |
    C -- R
    B : blue / G: green / C: cyan / R: red

    0 -- 1
    |    |
    3 -- 2
    '''
    # rename variables
    deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
    # sort deg ascending
    deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)

    deg_diff_map = np.abs(deg1_map - deg2_map)
    # we only consider the smallest degree of intersect
    deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]

    # define available degree range
    deg_range = [60, 120]
    outside_ratio = params['outside_ratio']  # over ratio >>> drop it!
    inside_ratio = params['inside_ratio']  # over ratio >>> drop it!

    corner_dict = {corner_info: [] for corner_info in range(4)}
    inter_points = []
    for i in range(inter_pts.shape[0]):
        for j in range(i + 1, inter_pts.shape[1]):
            # i, j > line index, always i < j
            x, y = inter_pts[i, j, :]
            deg1, deg2 = deg_sort[i, j, :]
            deg_diff = deg_diff_map[i, j]

            check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]

            check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and
                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or
                              (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and
                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
                             ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and
                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or
                              (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and
                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))

            if check_degree and check_distance:
                corner_info = None

                if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
                        (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
                    corner_info = 0  # blue
                elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
                    corner_info = 1  # green
                elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
                    corner_info = 2  # black
                elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
                        (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
                    corner_info = 3  # cyan
                else:
                    # red: we don't use it
                    continue

                corner_dict[corner_info].append([x, y, i, j])
                inter_points.append([x, y])

    square_list = []
    connect_list = []
    segments_list = []
    for corner0 in corner_dict[0]:
        for corner1 in corner_dict[1]:
            connect01 = False
            for corner0_line in corner0[2:]:
                if corner0_line in corner1[2:]:
                    connect01 = True
                    break
            if connect01:
                for corner2 in corner_dict[2]:
                    connect12 = False
                    for corner1_line in corner1[2:]:
                        if corner1_line in corner2[2:]:
                            connect12 = True
                            break
                    if connect12:
                        for corner3 in corner_dict[3]:
                            connect23 = False
                            for corner2_line in corner2[2:]:
                                if corner2_line in corner3[2:]:
                                    connect23 = True
                                    break
                            if connect23:
                                for corner3_line in corner3[2:]:
                                    if corner3_line in corner0[2:]:
                                        # SQUARE!!!
                                        '''
                                        0 -- 1
                                        |    |
                                        3 -- 2
                                        square_list:
                                            order: 0 > 1 > 2 > 3
                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
                                            ...
                                        connect_list:
                                            order: 01 > 12 > 23 > 30
                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
                                            ...
                                        segments_list:
                                            order: 0 > 1 > 2 > 3
                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
                                            ...
                                        '''
                                        square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
                                        connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
                                        segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])

    def check_outside_inside(segments_info, connect_idx):
        # return 'outside or inside', min distance, cover_param, peri_param
        if connect_idx == segments_info[0]:
            check_dist_mat = dist_inter_to_segment1
        else:
            check_dist_mat = dist_inter_to_segment2

        i, j = segments_info
        min_dist, max_dist = check_dist_mat[i, j, :]
        connect_dist = dist_segments[connect_idx]
        if max_dist > connect_dist:
            return 'outside', min_dist, 0, 1
        else:
            return 'inside', min_dist, -1, -1

    # Scoring is best-effort: when no square was assembled the array maths below raises
    # (e.g. max of an empty array) and the fallbacks further down return empty results.
    try:
        map_size = input_shape[0] / 2

        squares = np.array(square_list).reshape([-1, 4, 2])
        score_array = []
        connect_array = np.array(connect_list)
        segments_array = np.array(segments_list).reshape([-1, 4, 2])

        # get degree of corners:
        squares_rollup = np.roll(squares, 1, axis=1)
        squares_rolldown = np.roll(squares, -1, axis=1)
        vec1 = squares_rollup - squares
        normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
        vec2 = squares_rolldown - squares
        normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
        inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1)  # [n_squares, 4]
        squares_degree = np.arccos(inner_products) * 180 / np.pi  # [n_squares, 4]

        # get square score
        overlap_scores = []
        degree_scores = []
        length_scores = []

        for connects, corner_segments, degree in zip(connect_array, segments_array, squares_degree):
            '''
            0 -- 1
            |    |
            3 -- 2

            # corner_segments: [4, 2]
            # connects: [4]
            '''

            ###################################### OVERLAP SCORES
            cover = 0
            perimeter = 0
            # check 0 > 1 > 2 > 3
            square_length = []

            for start_idx in range(4):
                end_idx = (start_idx + 1) % 4

                connect_idx = connects[start_idx]  # segment idx of segment01
                start_segments = corner_segments[start_idx]
                end_segments = corner_segments[end_idx]

                # check whether outside or inside
                _, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, connect_idx)
                _, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)

                cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
                perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min

                square_length.append(
                    dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)

            overlap_scores.append(cover / perimeter)
            ######################################

            ###################################### DEGREE SCORES
            '''
            deg0 vs deg2
            deg1 vs deg3
            '''
            deg0, deg1, deg2, deg3 = degree
            deg_ratio1 = deg0 / deg2
            if deg_ratio1 > 1.0:
                deg_ratio1 = 1 / deg_ratio1
            deg_ratio2 = deg1 / deg3
            if deg_ratio2 > 1.0:
                deg_ratio2 = 1 / deg_ratio2
            degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
            ######################################

            ###################################### LENGTH SCORES
            '''
            len0 vs len2
            len1 vs len3
            '''
            len0, len1, len2, len3 = square_length
            len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
            len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
            length_scores.append((len_ratio1 + len_ratio2) / 2)

            ######################################

        overlap_scores = np.array(overlap_scores)
        overlap_scores /= np.max(overlap_scores)

        degree_scores = np.array(degree_scores)
        # degree_scores /= np.max(degree_scores)

        length_scores = np.array(length_scores)

        ###################################### AREA SCORES
        area_scores = np.reshape(squares, [-1, 4, 2])
        area_x = area_scores[:, :, 0]
        area_y = area_scores[:, :, 1]
        correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
        area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
        area_scores = 0.5 * np.abs(area_scores + correction)
        area_scores /= (map_size * map_size)  # np.max(area_scores)
        ######################################

        ###################################### CENTER SCORES
        # NOTE(review): the map centre is hard-coded to 256 // 2 and the sum below has no axis,
        # so one scalar is shared by all squares - left as-is to keep scores unchanged; confirm intent.
        centers = np.array([[256 // 2, 256 // 2]], dtype='float32')  # [1, 2]
        # squares: [n, 4, 2]
        square_centers = np.mean(squares, axis=1)  # [n, 2]
        center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
        center_scores = center2center / (map_size / np.sqrt(2.0))

        '''
        score_w = [overlap, degree, area, center, length]
        '''
        score_array = params['w_overlap'] * overlap_scores \
                      + params['w_degree'] * degree_scores \
                      + params['w_area'] * area_scores \
                      - params['w_center'] * center_scores \
                      + params['w_length'] * length_scores

        sorted_idx = np.argsort(score_array)[::-1]
        score_array = score_array[sorted_idx]
        squares = squares[sorted_idx]

    except Exception:
        pass

    '''return list
    merged_lines, squares, scores
    '''

    # Rescale from the half-resolution map back to the original image size.
    try:
        new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
        new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
        new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
        new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
    except Exception:
        new_segments = []

    try:
        squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
        squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
    except Exception:
        squares = []
        score_array = []

    try:
        inter_points = np.array(inter_points)
        inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
        inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
    except Exception:
        inter_points = []

    return new_segments, squares, score_array, inter_points
def load_checkpoint(fpath, model):
    """Load the ``'model'`` state dict stored at ``fpath`` into ``model`` and return the model.

    A leading ``module.`` on a key is stripped before loading. Only the prefix is
    removed: a key such as ``module.submodule.weight`` must become ``submodule.weight``,
    not lose the ``module.`` that is part of ``submodule.`` as well.

    Args:
        fpath: path of a checkpoint readable by ``torch.load`` that holds a dict with a ``'model'`` entry.
        model: module whose ``load_state_dict`` receives the cleaned state dict.

    Returns:
        The same ``model`` object, with weights loaded.
    """
    # NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
    ckpt = torch.load(fpath, map_location='cpu')['model']

    prefix = 'module.'
    load_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }

    model.load_state_dict(load_dict)
    return model
def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
    """Run the model on ``input_image`` and return its first three output channels as an image.

    The input is converted to a uint8 array if needed, resized to ``detect_resolution``,
    scaled to 0..1 and normalized with ``self.norm``. The prediction is mapped from
    [-1, 1] to 0..255 and resized to the size ``input_image`` would have at
    ``image_resolution``. A PIL image is returned when ``output_type`` is "pil",
    otherwise a uint8 numpy array.
    """
    self.model.to(devices.device)
    device = next(iter(self.model.parameters())).device

    if not isinstance(input_image, np.ndarray):
        input_image = np.array(input_image, dtype=np.uint8)
    input_image = resize_image(HWC3(input_image), detect_resolution)
    assert input_image.ndim == 3

    # HWC uint8 -> normalized 1CHW float tensor on the model device
    net_input = torch.from_numpy(input_image).float().to(device)
    net_input = net_input / 255.0
    net_input = self.norm(rearrange(net_input, 'h w c -> 1 c h w'))

    # first element of the model output, last entry of it, first three channels
    prediction = self.model(net_input)[0][-1][:, :3]
    prediction = ((prediction + 1) * 0.5).clip(0, 1)

    normal = rearrange(prediction[0], 'c h w -> h w c').cpu().numpy()
    detected_map = HWC3((normal * 255.0).clip(0, 255).astype(np.uint8))

    # target size is whatever the detection input becomes at image_resolution
    H, W, _C = resize_image(input_image, image_resolution).shape
    detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

    if output_type == "pil":
        detected_map = Image.fromarray(detected_map)

    if opts.control_move_processor:
        self.model.to('cpu')
    return detected_map
class NNET(nn.Module):
    """Encoder/decoder pair; the decoder consumes whatever the encoder returns."""

    def __init__(self, args):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):
        """Return the encoder parameters (lr/10 learning rate)."""
        encoder = self.encoder
        return encoder.parameters()

    def get_10x_lr_params(self):
        """Return the decoder parameters (lr learning rate)."""
        decoder = self.decoder
        return decoder.parameters()

    def forward(self, img, **kwargs):
        """Encode ``img`` and decode the result, forwarding ``kwargs`` to the decoder."""
        encoded = self.encoder(img)
        return self.decoder(encoded, **kwargs)
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
class Decoder(nn.Module):
    """Baseline decoder: a 1x1 conv, four skip-connected upsampling stages, and a 3x3 output conv."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)

        # skip_input = channels of the running feature map + channels of the skip feature
        self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
        self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
        self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
        self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)

        self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        """Decode using entries 4, 5, 6, 8 and 11 of ``features``, deepest entry first."""
        skip0, skip1, skip2, skip3, deepest = (features[idx] for idx in (4, 5, 6, 8, 11))

        x = self.conv2(deepest)
        x = self.up1(x, skip3)
        x = self.up2(x, skip2)
        x = self.up3(x, skip1)
        x = self.up4(x, skip0)
        return self.conv3(x)
def forward(self, features, gt_norm_mask=None, mode='test'):
    """Decode encoder features into a coarse prediction plus three successive 2x refinements.

    Args:
        features: indexable collection of encoder feature maps; entries 4, 5, 6, 8 and 11 are used.
        gt_norm_mask: only used in 'train' mode, where it is forwarded to ``sample_points``.
        mode: 'train' refines only the points chosen by ``sample_points`` and writes them
            into the upsampled previous output; any other value refines every pixel.

    Returns:
        Three 4-item lists ordered res8, res4, res2, res1:
        the dense outputs, the per-sample predictions (the res8 entry is the dense
        ``out_res8``; the others are None outside 'train' mode), and the sampled
        point coordinates (None for res8 and outside 'train' mode).
    """
    x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]

    # generate feature-map
    # (shape/resolution notes below are the original authors' example for a batch of 2 - not verified here)

    x_d0 = self.conv2(x_block4)                     # x_d0 : [2, 2048, 15, 20]      1/32 res
    x_d1 = self.up1(x_d0, x_block3)                 # x_d1 : [2, 1024, 30, 40]      1/16 res
    x_d2 = self.up2(x_d1, x_block2)                 # x_d2 : [2, 512, 60, 80]       1/8 res
    x_d3 = self.up3(x_d2, x_block1)                 # x_d3: [2, 256, 120, 160]      1/4 res
    x_d4 = self.up4(x_d3, x_block0)                 # x_d4: [2, 128, 240, 320]      1/2 res

    # 1/8 res output
    out_res8 = self.out_conv_res8(x_d2)             # out_res8: [2, 4, 60, 80]      1/8 res output
    out_res8 = norm_normalize(out_res8)             # out_res8: [2, 4, 60, 80]      1/8 res output

    ################################################################################################################
    # out_res4
    ################################################################################################################

    if mode == 'train':
        # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
        out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
        B, _, H, W = out_res8_res4.shape

        # samples: [B, 1, N, 2]
        point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
                                                              sampling_ratio=self.sampling_ratio,
                                                              beta=self.importance_ratio)

        # output (needed for evaluation / visualization)
        # same tensor object as out_res8_res4; the loop below overwrites the sampled pixels in place
        out_res4 = out_res8_res4

        # grid_sample feature-map
        feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True)  # (B, 512, 1, N)
        init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
        feat_res4 = torch.cat([feat_res4, init_pred], dim=1)  # (B, 512+4, 1, N)

        # prediction (needed to compute loss)
        samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :])  # (B, 4, N)
        samples_pred_res4 = norm_normalize(samples_pred_res4)  # (B, 4, N) - normalized

        for i in range(B):
            out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]

    else:
        # upsample features and the previous prediction 2x, then refine every pixel
        feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
        init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 512+4, H, W)
        B, _, H, W = feat_map.shape

        # try all pixels
        out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1))  # (B, 4, N)
        out_res4 = norm_normalize(out_res4)  # (B, 4, N) - normalized
        out_res4 = out_res4.view(B, 4, H, W)
        samples_pred_res4 = point_coords_res4 = None

    ################################################################################################################
    # out_res2
    ################################################################################################################

    if mode == 'train':

        # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
        out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
        B, _, H, W = out_res4_res2.shape

        # samples: [B, 1, N, 2]
        point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
                                                              sampling_ratio=self.sampling_ratio,
                                                              beta=self.importance_ratio)

        # output (needed for evaluation / visualization)
        # same tensor object as out_res4_res2; sampled pixels are overwritten in place below
        out_res2 = out_res4_res2

        # grid_sample feature-map
        feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True)  # (B, 256, 1, N)
        init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
        feat_res2 = torch.cat([feat_res2, init_pred], dim=1)  # (B, 256+4, 1, N)

        # prediction (needed to compute loss)
        samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :])  # (B, 4, N)
        samples_pred_res2 = norm_normalize(samples_pred_res2)  # (B, 4, N) - normalized

        for i in range(B):
            out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]

    else:
        # upsample features and the previous prediction 2x, then refine every pixel
        feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
        init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 256+4, H, W)
        B, _, H, W = feat_map.shape

        out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1))  # (B, 4, N)
        out_res2 = norm_normalize(out_res2)  # (B, 4, N) - normalized
        out_res2 = out_res2.view(B, 4, H, W)
        samples_pred_res2 = point_coords_res2 = None

    ################################################################################################################
    # out_res1
    ################################################################################################################

    if mode == 'train':
        # upsampling ... out_res2 -> out_res2_res1 (2x in height and width)
        out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
        B, _, H, W = out_res2_res1.shape

        # samples: [B, 1, N, 2]
        point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
                                                              sampling_ratio=self.sampling_ratio,
                                                              beta=self.importance_ratio)

        # output (needed for evaluation / visualization)
        # same tensor object as out_res2_res1; sampled pixels are overwritten in place below
        out_res1 = out_res2_res1

        # grid_sample feature-map
        feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True)  # (B, 128, 1, N)
        init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
        feat_res1 = torch.cat([feat_res1, init_pred], dim=1)  # (B, 128+4, 1, N)

        # prediction (needed to compute loss)
        samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :])  # (B, 4, N)
        samples_pred_res1 = norm_normalize(samples_pred_res1)  # (B, 4, N) - normalized

        for i in range(B):
            out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]

    else:
        # upsample features and the previous prediction 2x, then refine every pixel
        feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
        init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 128+4, H, W)
        B, _, H, W = feat_map.shape

        out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1))  # (B, 4, N)
        out_res1 = norm_normalize(out_res1)  # (B, 4, N) - normalized
        out_res1 = out_res1.view(B, 4, H, W)
        samples_pred_res1 = point_coords_res1 = None

    return [out_res8, out_res4, out_res2, out_res1], \
           [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
           [None, point_coords_res4, point_coords_res2, point_coords_res1]
modules/control/proc/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md ================================================ # Model Performance Benchmarks All benchmarks run as per: ``` python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb ``` ## EfficientNet-B0 ### Unoptimized ``` Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 Time per operator type: 29.7378 ms. 60.5145%. Conv 12.1785 ms. 24.7824%. Sigmoid 3.62811 ms. 7.38297%. SpatialBN 2.98444 ms. 6.07314%. Mul 0.326902 ms. 0.665225%. AveragePool 0.197317 ms. 0.401528%. FC 0.0852877 ms. 0.173555%. Add 0.0032607 ms. 0.00663532%. Squeeze 49.1416 ms in Total FLOP per operator type: 0.76907 GFLOP. 95.2696%. Conv 0.0269508 GFLOP. 3.33857%. SpatialBN 0.00846444 GFLOP. 1.04855%. Mul 0.002561 GFLOP. 0.317248%. FC 0.000210112 GFLOP. 0.0260279%. Add 0.807256 GFLOP in Total Feature Memory Read per operator type: 58.5253 MB. 43.0891%. Mul 43.2015 MB. 31.807%. Conv 27.2869 MB. 20.0899%. SpatialBN 5.12912 MB. 3.77631%. FC 1.6809 MB. 1.23756%. Add 135.824 MB in Total Feature Memory Written per operator type: 33.8578 MB. 38.1965%. Mul 26.9881 MB. 30.4465%. Conv 26.9508 MB. 30.4044%. SpatialBN 0.840448 MB. 0.948147%. Add 0.004 MB. 0.00451258%. FC 88.6412 MB in Total Parameter Memory per operator type: 15.8248 MB. 74.9391%. Conv 5.124 MB. 24.265%. FC 0.168064 MB. 0.795877%. SpatialBN 0 MB. 0%. Add 0 MB. 0%. Mul 21.1168 MB in Total ``` ### Optimized ``` Main run finished. Milliseconds per iter: 46.0838. 
Iters per second: 21.6996 Time per operator type: 29.776 ms. 65.002%. Conv 12.2803 ms. 26.8084%. Sigmoid 3.15073 ms. 6.87815%. Mul 0.328651 ms. 0.717456%. AveragePool 0.186237 ms. 0.406563%. FC 0.0832429 ms. 0.181722%. Add 0.0026184 ms. 0.00571606%. Squeeze 45.8078 ms in Total FLOP per operator type: 0.76907 GFLOP. 98.5601%. Conv 0.00846444 GFLOP. 1.08476%. Mul 0.002561 GFLOP. 0.328205%. FC 0.000210112 GFLOP. 0.0269269%. Add 0.780305 GFLOP in Total Feature Memory Read per operator type: 58.5253 MB. 53.8803%. Mul 43.2855 MB. 39.8501%. Conv 5.12912 MB. 4.72204%. FC 1.6809 MB. 1.54749%. Add 108.621 MB in Total Feature Memory Written per operator type: 33.8578 MB. 54.8834%. Mul 26.9881 MB. 43.7477%. Conv 0.840448 MB. 1.36237%. Add 0.004 MB. 0.00648399%. FC 61.6904 MB in Total Parameter Memory per operator type: 15.8248 MB. 75.5403%. Conv 5.124 MB. 24.4597%. FC 0 MB. 0%. Add 0 MB. 0%. Mul 20.9488 MB in Total ``` ## EfficientNet-B1 ### Optimized ``` Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 Time per operator type: 45.7915 ms. 66.3206%. Conv 17.8718 ms. 25.8841%. Sigmoid 4.44132 ms. 6.43244%. Mul 0.51001 ms. 0.738658%. AveragePool 0.233283 ms. 0.337868%. Add 0.194986 ms. 0.282402%. FC 0.00268255 ms. 0.00388519%. Squeeze 69.0456 ms in Total FLOP per operator type: 1.37105 GFLOP. 98.7673%. Conv 0.0138759 GFLOP. 0.99959%. Mul 0.002561 GFLOP. 0.184489%. FC 0.000674432 GFLOP. 0.0485847%. Add 1.38816 GFLOP in Total Feature Memory Read per operator type: 94.624 MB. 54.0789%. Mul 69.8255 MB. 39.9062%. Conv 5.39546 MB. 3.08357%. Add 5.12912 MB. 2.93136%. FC 174.974 MB in Total Feature Memory Written per operator type: 55.5035 MB. 54.555%. Mul 43.5333 MB. 42.7894%. Conv 2.69773 MB. 2.65163%. Add 0.004 MB. 0.00393165%. FC 101.739 MB in Total Parameter Memory per operator type: 25.7479 MB. 83.4024%. Conv 5.124 MB. 16.5976%. FC 0 MB. 0%. Add 0 MB. 0%. Mul 30.8719 MB in Total ``` ## EfficientNet-B2 ### Optimized ``` Main run finished. 
Milliseconds per iter: 92.28. Iters per second: 10.8366 Time per operator type: 61.4627 ms. 67.5845%. Conv 22.7458 ms. 25.0113%. Sigmoid 5.59931 ms. 6.15701%. Mul 0.642567 ms. 0.706568%. AveragePool 0.272795 ms. 0.299965%. Add 0.216178 ms. 0.237709%. FC 0.00268895 ms. 0.00295677%. Squeeze 90.942 ms in Total FLOP per operator type: 1.98431 GFLOP. 98.9343%. Conv 0.0177039 GFLOP. 0.882686%. Mul 0.002817 GFLOP. 0.140451%. FC 0.000853984 GFLOP. 0.0425782%. Add 2.00568 GFLOP in Total Feature Memory Read per operator type: 120.609 MB. 54.9637%. Mul 86.3512 MB. 39.3519%. Conv 6.83187 MB. 3.11341%. Add 5.64163 MB. 2.571%. FC 219.433 MB in Total Feature Memory Written per operator type: 70.8155 MB. 54.6573%. Mul 55.3273 MB. 42.7031%. Conv 3.41594 MB. 2.63651%. Add 0.004 MB. 0.00308731%. FC 129.563 MB in Total Parameter Memory per operator type: 30.4721 MB. 84.3913%. Conv 5.636 MB. 15.6087%. FC 0 MB. 0%. Add 0 MB. 0%. Mul 36.1081 MB in Total ``` ## MixNet-M ### Optimized ``` Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 Time per operator type: 48.1139 ms. 75.2052%. Conv 7.1341 ms. 11.1511%. Sigmoid 2.63706 ms. 4.12189%. SpatialBN 1.73186 ms. 2.70701%. Mul 1.38707 ms. 2.16809%. Split 1.29322 ms. 2.02139%. Concat 1.00093 ms. 1.56452%. Relu 0.235309 ms. 0.367803%. Add 0.221579 ms. 0.346343%. FC 0.219315 ms. 0.342803%. AveragePool 0.00250145 ms. 0.00390993%. Squeeze 63.9768 ms in Total FLOP per operator type: 0.675273 GFLOP. 95.5827%. Conv 0.0221072 GFLOP. 3.12921%. SpatialBN 0.00538445 GFLOP. 0.762152%. Mul 0.003073 GFLOP. 0.434973%. FC 0.000642488 GFLOP. 0.0909421%. Add 0 GFLOP. 0%. Concat 0 GFLOP. 0%. Relu 0.70648 GFLOP in Total Feature Memory Read per operator type: 46.8424 MB. 30.502%. Conv 36.8626 MB. 24.0036%. Mul 22.3152 MB. 14.5309%. SpatialBN 22.1074 MB. 14.3955%. Concat 14.1496 MB. 9.21372%. Relu 6.15414 MB. 4.00735%. FC 5.1399 MB. 3.34692%. Add 153.571 MB in Total Feature Memory Written per operator type: 32.7672 MB. 28.4331%. 
Conv 22.1072 MB. 19.1831%. Concat 22.1072 MB. 19.1831%. SpatialBN 21.5378 MB. 18.689%. Mul 14.1496 MB. 12.2781%. Relu 2.56995 MB. 2.23003%. Add 0.004 MB. 0.00347092%. FC 115.243 MB in Total Parameter Memory per operator type: 13.7059 MB. 68.674%. Conv 6.148 MB. 30.8049%. FC 0.104 MB. 0.521097%. SpatialBN 0 MB. 0%. Add 0 MB. 0%. Concat 0 MB. 0%. Mul 0 MB. 0%. Relu 19.9579 MB in Total ``` ## TF MobileNet-V3 Large 1.0 ### Optimized ``` Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 Time per operator type: 17.437 ms. 80.0087%. Conv 1.27662 ms. 5.8577%. Add 1.12759 ms. 5.17387%. Div 0.701155 ms. 3.21721%. Mul 0.562654 ms. 2.58171%. Relu 0.431144 ms. 1.97828%. Clip 0.156902 ms. 0.719936%. FC 0.0996858 ms. 0.457402%. AveragePool 0.00112455 ms. 0.00515993%. Flatten 21.7939 ms in Total FLOP per operator type: 0.43062 GFLOP. 98.1484%. Conv 0.002561 GFLOP. 0.583713%. FC 0.00210867 GFLOP. 0.480616%. Mul 0.00193868 GFLOP. 0.441871%. Add 0.00151532 GFLOP. 0.345377%. Div 0 GFLOP. 0%. Relu 0.438743 GFLOP in Total Feature Memory Read per operator type: 34.7967 MB. 43.9391%. Conv 14.496 MB. 18.3046%. Mul 9.44828 MB. 11.9307%. Add 9.26157 MB. 11.6949%. Relu 6.0614 MB. 7.65395%. Div 5.12912 MB. 6.47673%. FC 79.193 MB in Total Feature Memory Written per operator type: 17.6247 MB. 35.8656%. Conv 9.26157 MB. 18.847%. Relu 8.43469 MB. 17.1643%. Mul 7.75472 MB. 15.7806%. Add 6.06128 MB. 12.3345%. Div 0.004 MB. 0.00813985%. FC 49.1409 MB in Total Parameter Memory per operator type: 16.6851 MB. 76.5052%. Conv 5.124 MB. 23.4948%. FC 0 MB. 0%. Add 0 MB. 0%. Div 0 MB. 0%. Mul 0 MB. 0%. Relu 21.8091 MB in Total ``` ## MobileNet-V3 (RW) ### Unoptimized ``` Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 Time per operator type: 15.9266 ms. 69.2624%. Conv 2.36551 ms. 10.2873%. SpatialBN 1.39102 ms. 6.04936%. Add 1.30327 ms. 5.66773%. Div 0.737014 ms. 3.20517%. Mul 0.639697 ms. 2.78195%. Relu 0.375681 ms. 1.63378%. Clip 0.153126 ms. 
0.665921%. FC 0.0993787 ms. 0.432184%. AveragePool 0.0032632 ms. 0.0141912%. Squeeze 22.9946 ms in Total FLOP per operator type: 0.430616 GFLOP. 94.4041%. Conv 0.0175992 GFLOP. 3.85829%. SpatialBN 0.002561 GFLOP. 0.561449%. FC 0.00210961 GFLOP. 0.46249%. Mul 0.00173891 GFLOP. 0.381223%. Add 0.00151626 GFLOP. 0.33241%. Div 0 GFLOP. 0%. Relu 0.456141 GFLOP in Total Feature Memory Read per operator type: 34.7354 MB. 36.4363%. Conv 17.7944 MB. 18.6658%. SpatialBN 14.5035 MB. 15.2137%. Mul 9.25778 MB. 9.71113%. Relu 7.84641 MB. 8.23064%. Add 6.06516 MB. 6.36216%. Div 5.12912 MB. 5.38029%. FC 95.3317 MB in Total Feature Memory Written per operator type: 17.6246 MB. 26.7264%. Conv 17.5992 MB. 26.6878%. SpatialBN 9.25778 MB. 14.0387%. Relu 8.43843 MB. 12.7962%. Mul 6.95565 MB. 10.5477%. Add 6.06502 MB. 9.19713%. Div 0.004 MB. 0.00606568%. FC 65.9447 MB in Total Parameter Memory per operator type: 16.6778 MB. 76.1564%. Conv 5.124 MB. 23.3979%. FC 0.0976 MB. 0.445674%. SpatialBN 0 MB. 0%. Add 0 MB. 0%. Div 0 MB. 0%. Mul 0 MB. 0%. Relu 21.8994 MB in Total ``` ### Optimized ``` Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 Time per operator type: 17.146 ms. 78.8965%. Conv 1.38453 ms. 6.37084%. Add 1.30991 ms. 6.02749%. Div 0.685417 ms. 3.15391%. Mul 0.532589 ms. 2.45068%. Relu 0.418263 ms. 1.92461%. Clip 0.15128 ms. 0.696106%. FC 0.102065 ms. 0.469648%. AveragePool 0.0022143 ms. 0.010189%. Squeeze 21.7323 ms in Total FLOP per operator type: 0.430616 GFLOP. 98.1927%. Conv 0.002561 GFLOP. 0.583981%. FC 0.00210961 GFLOP. 0.481051%. Mul 0.00173891 GFLOP. 0.396522%. Add 0.00151626 GFLOP. 0.34575%. Div 0 GFLOP. 0%. Relu 0.438542 GFLOP in Total Feature Memory Read per operator type: 34.7842 MB. 44.833%. Conv 14.5035 MB. 18.6934%. Mul 9.25778 MB. 11.9323%. Relu 7.84641 MB. 10.1132%. Add 6.06516 MB. 7.81733%. Div 5.12912 MB. 6.61087%. FC 77.5861 MB in Total Feature Memory Written per operator type: 17.6246 MB. 36.4556%. Conv 9.25778 MB. 19.1492%. 
Relu 8.43843 MB. 17.4544%. Mul 6.95565 MB. 14.3874%. Add 6.06502 MB. 12.5452%. Div 0.004 MB. 0.00827378%. FC 48.3455 MB in Total Parameter Memory per operator type: 16.6778 MB. 76.4973%. Conv 5.124 MB. 23.5027%. FC 0 MB. 0%. Add 0 MB. 0%. Div 0 MB. 0%. Mul 0 MB. 0%. Relu 21.8018 MB in Total ``` ## MnasNet-A1 ### Unoptimized ``` Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 Time per operator type: 24.4656 ms. 79.0905%. Conv 4.14958 ms. 13.4144%. SpatialBN 1.60598 ms. 5.19169%. Relu 0.295219 ms. 0.95436%. Mul 0.187609 ms. 0.606486%. FC 0.120556 ms. 0.389724%. AveragePool 0.09036 ms. 0.292109%. Add 0.015727 ms. 0.050841%. Sigmoid 0.00306205 ms. 0.00989875%. Squeeze 30.9337 ms in Total FLOP per operator type: 0.620598 GFLOP. 95.6434%. Conv 0.0248873 GFLOP. 3.8355%. SpatialBN 0.002561 GFLOP. 0.394688%. FC 0.000597408 GFLOP. 0.0920695%. Mul 0.000222656 GFLOP. 0.0343146%. Add 0 GFLOP. 0%. Relu 0.648867 GFLOP in Total Feature Memory Read per operator type: 35.5457 MB. 38.4109%. Conv 25.1552 MB. 27.1829%. SpatialBN 22.5235 MB. 24.339%. Relu 5.12912 MB. 5.54256%. FC 2.40586 MB. 2.59978%. Mul 1.78125 MB. 1.92483%. Add 92.5406 MB in Total Feature Memory Written per operator type: 24.9042 MB. 32.9424%. Conv 24.8873 MB. 32.92%. SpatialBN 22.5235 MB. 29.7932%. Relu 2.38963 MB. 3.16092%. Mul 0.890624 MB. 1.17809%. Add 0.004 MB. 0.00529106%. FC 75.5993 MB in Total Parameter Memory per operator type: 10.2732 MB. 66.1459%. Conv 5.124 MB. 32.9917%. FC 0.133952 MB. 0.86247%. SpatialBN 0 MB. 0%. Add 0 MB. 0%. Mul 0 MB. 0%. Relu 15.5312 MB in Total ``` ### Optimized ``` Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 Time per operator type: 22.0547 ms. 91.1375%. Conv 1.49096 ms. 6.16116%. Relu 0.253417 ms. 1.0472%. Mul 0.18506 ms. 0.76473%. FC 0.112942 ms. 0.466717%. AveragePool 0.086769 ms. 0.358559%. Add 0.0127889 ms. 0.0528479%. Sigmoid 0.0027346 ms. 0.0113003%. 
Squeeze 24.1994 ms in Total FLOP per operator type: 0.620598 GFLOP. 99.4581%. Conv 0.002561 GFLOP. 0.41043%. FC 0.000597408 GFLOP. 0.0957417%. Mul 0.000222656 GFLOP. 0.0356832%. Add 0 GFLOP. 0%. Relu 0.623979 GFLOP in Total Feature Memory Read per operator type: 35.6127 MB. 52.7968%. Conv 22.5235 MB. 33.3917%. Relu 5.12912 MB. 7.60406%. FC 2.40586 MB. 3.56675%. Mul 1.78125 MB. 2.64075%. Add 67.4524 MB in Total Feature Memory Written per operator type: 24.9042 MB. 49.1092%. Conv 22.5235 MB. 44.4145%. Relu 2.38963 MB. 4.71216%. Mul 0.890624 MB. 1.75624%. Add 0.004 MB. 0.00788768%. FC 50.712 MB in Total Parameter Memory per operator type: 10.2732 MB. 66.7213%. Conv 5.124 MB. 33.2787%. FC 0 MB. 0%. Add 0 MB. 0%. Mul 0 MB. 0%. Relu 15.3972 MB in Total ``` ## MnasNet-B1 ### Unoptimized ``` Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 Time per operator type: 29.1121 ms. 83.3081%. Conv 4.14959 ms. 11.8746%. SpatialBN 1.35823 ms. 3.88675%. Relu 0.186188 ms. 0.532802%. FC 0.116244 ms. 0.332647%. Add 0.018641 ms. 0.0533437%. AveragePool 0.0040904 ms. 0.0117052%. Squeeze 34.9451 ms in Total FLOP per operator type: 0.626272 GFLOP. 96.2088%. Conv 0.0218266 GFLOP. 3.35303%. SpatialBN 0.002561 GFLOP. 0.393424%. FC 0.000291648 GFLOP. 0.0448034%. Add 0 GFLOP. 0%. Relu 0.650951 GFLOP in Total Feature Memory Read per operator type: 34.4354 MB. 41.3788%. Conv 22.1299 MB. 26.5921%. SpatialBN 19.1923 MB. 23.0622%. Relu 5.12912 MB. 6.16333%. FC 2.33318 MB. 2.80364%. Add 83.2199 MB in Total Feature Memory Written per operator type: 21.8266 MB. 34.0955%. Conv 21.8266 MB. 34.0955%. SpatialBN 19.1923 MB. 29.9805%. Relu 1.16659 MB. 1.82234%. Add 0.004 MB. 0.00624844%. FC 64.016 MB in Total Parameter Memory per operator type: 12.2576 MB. 69.9104%. Conv 5.124 MB. 29.2245%. FC 0.15168 MB. 0.865099%. SpatialBN 0 MB. 0%. Add 0 MB. 0%. Relu 17.5332 MB in Total ``` ### Optimized ``` Main run finished. Milliseconds per iter: 26.6364. 
Iters per second: 37.5426 Time per operator type: 24.9888 ms. 94.0962%. Conv 1.26147 ms. 4.75011%. Relu 0.176234 ms. 0.663619%. FC 0.113309 ms. 0.426672%. Add 0.0138708 ms. 0.0522311%. AveragePool 0.00295685 ms. 0.0111341%. Squeeze 26.5566 ms in Total FLOP per operator type: 0.626272 GFLOP. 99.5466%. Conv 0.002561 GFLOP. 0.407074%. FC 0.000291648 GFLOP. 0.0463578%. Add 0 GFLOP. 0%. Relu 0.629124 GFLOP in Total Feature Memory Read per operator type: 34.5112 MB. 56.4224%. Conv 19.1923 MB. 31.3775%. Relu 5.12912 MB. 8.3856%. FC 2.33318 MB. 3.81452%. Add 61.1658 MB in Total Feature Memory Written per operator type: 21.8266 MB. 51.7346%. Conv 19.1923 MB. 45.4908%. Relu 1.16659 MB. 2.76513%. Add 0.004 MB. 0.00948104%. FC 42.1895 MB in Total Parameter Memory per operator type: 12.2576 MB. 70.5205%. Conv 5.124 MB. 29.4795%. FC 0 MB. 0%. Add 0 MB. 0%. Relu 17.3816 MB in Total ``` ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2020 Ross Wightman Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/README.md ================================================ # (Generic) EfficientNets for PyTorch A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search. All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)) ## What's New ### Aug 19, 2020 * Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1) * Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1) * Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX * ONNX runtime based validation script added * activations (mostly) brought in sync with `timm` equivalents ### April 5, 2020 * Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. 
They compare quite favourably to EfficientNet-Lite * 3.5M param MobileNet-V2 100 @ 73% * 4.5M param MobileNet-V2 110d @ 75% * 6.1M param MobileNet-V2 140 @ 76.5% * 5.8M param MobileNet-V2 120d @ 77.3% ### March 23, 2020 * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1 * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior ### Feb 12, 2020 * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization. * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) ### Jan 22, 2020 * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models) * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues ### Nov 22, 2019 * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights. 
### Nov 15, 2019 * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine ### Oct 30, 2019 * Many of the models will now work with torch.jit.script, MixNet being the biggest exception * Improved interface for enabling torchscript or ONNX export compatible modes (via config) * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autograd.fn * Activation factory to select best version of activation by name or override one globally * Add pretrained checkpoint load helper that handles input conv and classifier changes ### Oct 27, 2019 * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base * Switch activations and global pooling to modules * Add memory-efficient Swish/Mish impl * Add as_sequential() method to all models and allow as an argument in entrypoint fns * Move MobileNetV3 into own file since it has a different head * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here ## Models Implemented models include: * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252) * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946) * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971) * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) * MixNet (https://arxiv.org/abs/1907.09595) * MNASNet B1, A1 (Squeeze-Excite), and Small
(https://arxiv.org/abs/1807.11626) * MobileNet-V3 (https://arxiv.org/abs/1905.02244) * FBNet-C (https://arxiv.org/abs/1812.03443) * Single-Path NAS (https://arxiv.org/abs/1904.02877) I originally implemented and trained some of these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code. ## Pretrained I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models |Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | |---|---|---|---|---|---|---|---| | efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | | efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | | mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | | efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | | mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | | efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | | mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | | efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | | efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | | efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | | mobilenetv2_120d | 77.294 (22.706) | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | | mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | | mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 
0.875 | | mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | | mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | | efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | | mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | | fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | | mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | | mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | | spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | | mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | More pretrained models to come... ## Ported Weights The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. **IMPORTANT:** * Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. * Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. 
To run validation for tf_efficientnet_b5: `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` To run validation w/ TF preprocessing for tf_efficientnet_b5: `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: `python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | |---|---|---|---|---|---|---| | tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | | tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | | tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | | tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | | tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | | tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | | tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | | tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | | tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | | tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | | tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | | tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | | tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | | tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 
| | tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | | tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | | tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | | tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | | tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | | tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | | tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | | tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | | tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | | tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | | tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | | tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | | tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | | tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | | tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | | tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | | tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | | tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | | tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | | tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | | tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | | tf_efficientnet_b2_ns | 82.380 
(17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | | tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | | tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | | tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | | tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | | tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | | tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | | tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | | tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | | tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | | tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | | tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | | tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | | tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | | tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | | tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | | tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | | tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | | tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | | tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | | tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | | tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | 
bicubic | 240 | N/A | | tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | | tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | | tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | | tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | | tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | | tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | | tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | | tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | | tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | | tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | | tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 | | tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | | tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | | tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | | tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | | tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | | tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | | tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | | tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | | tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | | tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | | tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 
(6.674) | 5.42 | bilinear | 240 | N/A | | tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | | tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | | tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | | tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | | tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | | tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | | tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | | tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 | N/A | | tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | | tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | | tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | | tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042) | 2.54 | bilinear | 224 | N/A | | tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | | tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | | tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | | tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | | tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | *tfp models validated with `tf-preprocessing` pipeline Google tf and tflite weights ported from official Tensorflow repositories * https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet * https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet * 
https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet ## Usage ### Environment All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x. Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself. PyTorch versions 1.4, 1.5, 1.6 have been tested with this code. I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: ``` conda create -n torch-env conda activate torch-env conda install -c pytorch pytorch torchvision cudatoolkit=10.2 ``` ### PyTorch Hub Models can be accessed via the PyTorch Hub API ``` >>> torch.hub.list('rwightman/gen-efficientnet-pytorch') ['efficientnet_b0', ...] >>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True) >>> model.eval() >>> output = model(torch.randn(1,3,224,224)) ``` ### Pip This package can be installed via pip. 
Install (after conda env/install): ``` pip install geffnet ``` Eval use: ``` >>> import geffnet >>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True) >>> m.eval() ``` Train use: ``` >>> import geffnet >>> # models can also be created by using the entrypoint directly >>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2) >>> m.train() ``` Create in a nn.Sequential container, for fast.ai, etc: ``` >>> import geffnet >>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True) ``` ### Exporting Scripts are included to * export models to ONNX (`onnx_export.py`) * optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg) * validate with ONNX runtime (`onnx_validate.py`) * convert ONNX model to Caffe2 (`onnx_to_caffe.py`) * validate in Caffe2 (`caffe2_validate.py`) * benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`) As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation: ``` python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx ``` These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible export now requires additional args mentioned in the export script (not needed in earlier versions). #### Export Notes 1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script. 2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working. 3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. 
Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization. 3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here. ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/__init__.py ================================================ ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py ================================================ from .gen_efficientnet import * from .mobilenetv3 import * from .model_factory import create_model from .config import is_exportable, is_scriptable, set_exportable, set_scriptable from .activations import * ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py ================================================ from geffnet import config from geffnet.activations.activations_me import * from geffnet.activations.activations_jit import * from geffnet.activations.activations import * import torch _has_silu = 'silu' in dir(torch.nn.functional) _ACT_FN_DEFAULT = dict( silu=F.silu if _has_silu else swish, swish=F.silu if _has_silu else swish, mish=mish, relu=F.relu, relu6=F.relu6, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=hard_sigmoid, hard_swish=hard_swish, ) _ACT_FN_JIT = dict( silu=F.silu if _has_silu else swish_jit, swish=F.silu if _has_silu else swish_jit, mish=mish_jit, ) _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, mish=mish_me, hard_swish=hard_swish_me, hard_sigmoid_jit=hard_sigmoid_me, ) _ACT_LAYER_DEFAULT = dict( silu=nn.SiLU if _has_silu else Swish, swish=nn.SiLU if _has_silu else Swish, mish=Mish, relu=nn.ReLU, relu6=nn.ReLU6, sigmoid=Sigmoid, tanh=Tanh, 
hard_sigmoid=HardSigmoid, hard_swish=HardSwish, ) _ACT_LAYER_JIT = dict( silu=nn.SiLU if _has_silu else SwishJit, swish=nn.SiLU if _has_silu else SwishJit, mish=MishJit, ) _ACT_LAYER_ME = dict( silu=nn.SiLU if _has_silu else SwishMe, swish=nn.SiLU if _has_silu else SwishMe, mish=MishMe, hard_swish=HardSwishMe, hard_sigmoid=HardSigmoidMe ) _OVERRIDE_FN = {} _OVERRIDE_LAYER = {} def add_override_act_fn(name, fn): global _OVERRIDE_FN _OVERRIDE_FN[name] = fn def update_override_act_fn(overrides): assert isinstance(overrides, dict) global _OVERRIDE_FN _OVERRIDE_FN.update(overrides) def clear_override_act_fn(): global _OVERRIDE_FN _OVERRIDE_FN = {} def add_override_act_layer(name, fn): _OVERRIDE_LAYER[name] = fn def update_override_act_layer(overrides): assert isinstance(overrides, dict) global _OVERRIDE_LAYER _OVERRIDE_LAYER.update(overrides) def clear_override_act_layer(): global _OVERRIDE_LAYER _OVERRIDE_LAYER = {} def get_act_fn(name='relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. 
""" if name in _OVERRIDE_FN: return _OVERRIDE_FN[name] use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) if use_me and name in _ACT_FN_ME: # If not exporting or scripting the model, first look for a memory optimized version # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin return _ACT_FN_ME[name] if config.is_exportable() and name in ('silu', 'swish'): # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack return swish use_jit = not (config.is_exportable() or config.is_no_jit()) # NOTE: export tracing should work with jit scripted components, but I keep running into issues if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting return _ACT_FN_JIT[name] return _ACT_FN_DEFAULT[name] def get_act_layer(name='relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. 
""" if name in _OVERRIDE_LAYER: return _OVERRIDE_LAYER[name] use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) if use_me and name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] if config.is_exportable() and name in ('silu', 'swish'): # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack return Swish use_jit = not (config.is_exportable() or config.is_no_jit()) # NOTE: export tracing should work with jit scripted components, but I keep running into issues if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting return _ACT_LAYER_JIT[name] return _ACT_LAYER_DEFAULT[name] ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py ================================================ """ Activations A collection of activations fn and modules with a common interface so that they can easily be swapped. All have an `inplace` arg even if not used. Copyright 2020 Ross Wightman """ from torch import nn as nn from torch.nn import functional as F def swish(x, inplace: bool = False): """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) and also as Swish (https://arxiv.org/abs/1710.05941). 
""" return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) class Swish(nn.Module): def __init__(self, inplace: bool = False): super(Swish, self).__init__() self.inplace = inplace def forward(self, x): return swish(x, self.inplace) def mish(x, inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class Mish(nn.Module): def __init__(self, inplace: bool = False): super(Mish, self).__init__() self.inplace = inplace def forward(self, x): return mish(x, self.inplace) def sigmoid(x, inplace: bool = False): return x.sigmoid_() if inplace else x.sigmoid() # PyTorch has this, but not with a consistent inplace argmument interface class Sigmoid(nn.Module): def __init__(self, inplace: bool = False): super(Sigmoid, self).__init__() self.inplace = inplace def forward(self, x): return x.sigmoid_() if self.inplace else x.sigmoid() def tanh(x, inplace: bool = False): return x.tanh_() if inplace else x.tanh() # PyTorch has this, but not with a consistent inplace argmument interface class Tanh(nn.Module): def __init__(self, inplace: bool = False): super(Tanh, self).__init__() self.inplace = inplace def forward(self, x): return x.tanh_() if self.inplace else x.tanh() def hard_swish(x, inplace: bool = False): inner = F.relu6(x + 3.).div_(6.) return x.mul_(inner) if inplace else x.mul(inner) class HardSwish(nn.Module): def __init__(self, inplace: bool = False): super(HardSwish, self).__init__() self.inplace = inplace def forward(self, x): return hard_swish(x, self.inplace) def hard_sigmoid(x, inplace: bool = False): if inplace: return x.add_(3.).clamp_(0., 6.).div_(6.) else: return F.relu6(x + 3.) / 6. 
class HardSigmoid(nn.Module): def __init__(self, inplace: bool = False): super(HardSigmoid, self).__init__() self.inplace = inplace def forward(self, x): return hard_sigmoid(x, self.inplace) ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py ================================================ """ Activations (jit) A collection of jit-scripted activations fn and modules with a common interface so that they can easily be swapped. All have an `inplace` arg even if not used. All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted versions if they contain in-place ops. Copyright 2020 Ross Wightman """ import torch from torch import nn as nn from torch.nn import functional as F __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] @torch.jit.script def swish_jit(x, inplace: bool = False): """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) and also as Swish (https://arxiv.org/abs/1710.05941). """ return x.mul(x.sigmoid()) @torch.jit.script def mish_jit(x, _inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class SwishJit(nn.Module): def __init__(self, inplace: bool = False): super(SwishJit, self).__init__() def forward(self, x): return swish_jit(x) class MishJit(nn.Module): def __init__(self, inplace: bool = False): super(MishJit, self).__init__() def forward(self, x): return mish_jit(x) @torch.jit.script def hard_sigmoid_jit(x, inplace: bool = False): # return F.relu6(x + 3.) / 6. return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
class HardSigmoidJit(nn.Module):
    """Module wrapper around `hard_sigmoid_jit`; the `inplace` argument is accepted but not stored or used."""

    def __init__(self, inplace: bool = False):
        super(HardSigmoidJit, self).__init__()

    def forward(self, x):
        return hard_sigmoid_jit(x)


@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    """Jit-scripted hard swish: x * clamp(x + 3, 0, 6) / 6. `inplace` is not used."""
    # return x * (F.relu6(x + 3.) / 6)
    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSwishJit(nn.Module):
    """Module wrapper around `hard_swish_jit`; the `inplace` argument is accepted but not stored or used."""

    def __init__(self, inplace: bool = False):
        super(HardSwishJit, self).__init__()

    def forward(self, x):
        return hard_swish_jit(x)


================================================
FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py
================================================
""" Activations (memory-efficient w/ custom autograd)

A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.

Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
           'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']


@torch.jit.script
def swish_jit_fwd(x):
    """Forward of Swish: x * sigmoid(x)."""
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """Backward of Swish: grad_output * sigmoid(x) * (1 + x * (1 - sigmoid(x))), recomputed from the input `x`."""
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200

    Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
    and also as Swish (https://arxiv.org/abs/1710.05941).
    """

    @staticmethod
    def forward(ctx, x):
        # Only the input is saved; backward recomputes the sigmoid from it.
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_me(x, inplace=False):
    """Functional entry point for `SwishJitAutoFn`; `inplace` is not used."""
    return SwishJitAutoFn.apply(x)


class SwishMe(nn.Module):
    """Module wrapper that applies `SwishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    """Forward of Mish: x * tanh(softplus(x))."""
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """Backward of Mish: grad_output * (tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)), with sp = softplus(x)."""
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        # Only the input is saved; backward recomputes the intermediates from it.
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_me(x, inplace=False):
    """Functional entry point for `MishJitAutoFn`; `inplace` is not used."""
    return MishJitAutoFn.apply(x)


class MishMe(nn.Module):
    """Module wrapper that applies `MishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
        return MishJitAutoFn.apply(x)


@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    """Forward of hard sigmoid: clamp(x + 3, 0, 6) / 6. `inplace` is not used."""
    return (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    """Backward of hard sigmoid: gradient factor is 1/6 where -3 <= x <= 3 and 0 elsewhere."""
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m


class HardSigmoidJitAutoFn(torch.autograd.Function):
    """Custom autograd hard sigmoid that saves only its input for the backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
    """Functional entry point for `HardSigmoidJitAutoFn`; `inplace` is not used."""
    return HardSigmoidJitAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    """Module wrapper that applies `HardSigmoidJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)


@torch.jit.script
def hard_swish_jit_fwd(x):
    """Forward of hard swish: x * clamp(x + 3, 0, 6) / 6."""
    return x * (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    """Backward of hard swish, computed from the saved input `x`."""
    # Start with a factor of 1 where x >= 3 and 0 everywhere else ...
    m = torch.ones_like(x) * (x >= 3.)
    # ... then use x / 3 + 0.5 wherever -3 <= x <= 3 (both bounds inclusive).
    m = torch.where((x >= -3.) & (x <= 3.),  x / 3. + .5, m)
    return grad_output * m


class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)


def hard_swish_me(x, inplace=False):
    """Functional entry point for `HardSwishJitAutoFn`; `inplace` is not used."""
    return HardSwishJitAutoFn.apply(x)


class HardSwishMe(nn.Module):
    """Module wrapper that applies `HardSwishJitAutoFn`; `inplace` is accepted but not used."""

    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)


================================================
FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py
================================================
""" Global layer config state
"""
from typing import Any, Optional

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]

# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_no_jit():
    """Return the current value of the module-level `_NO_JIT` flag."""
    return _NO_JIT


class set_no_jit:
    """Set `_NO_JIT` to `mode`, usable as a plain call or as a context manager.

    The flag is changed in `__init__`, i.e. as soon as the object is constructed;
    `__exit__` restores the value that was current at construction time.
    """

    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT  # remembered so __exit__ can restore it
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False  # never suppress an exception raised inside the with-block


def is_exportable():
    """Return the current value of the module-level `_EXPORTABLE` flag."""
    return _EXPORTABLE


class set_exportable:
    """Set `_EXPORTABLE` to `mode`, usable as a plain call or as a context manager.

    The flag is changed in `__init__`; `__exit__` restores the previous value.
    """

    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE  # remembered so __exit__ can restore it
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False  # never suppress an exception raised inside the with-block


def is_scriptable():
    """Return the current value of the module-level `_SCRIPTABLE` flag."""
    return _SCRIPTABLE


class set_scriptable:
    """Set `_SCRIPTABLE` to `mode`, usable as a plain call or as a context manager.

    The flag is changed in `__init__`; `__exit__` restores the previous value.
    """

    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE  # remembered so __exit__ can restore it
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False  # never suppress an exception raised inside the with-block


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.

    As with the single-flag setters, the flags are applied in `__init__` and all four
    are restored to their values at construction time in `__exit__`.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        # Snapshot all four flags before touching any of them.
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False  # never suppress an exception raised inside the with-block


def layer_config_kwargs(kwargs):
    """ Consume config kwargs and return contextmgr obj

    Removes the 'scriptable', 'exportable' and 'no_jit' keys from `kwargs` (the dict
    is modified in place) and passes them to `set_layer_config`; missing keys become None.
    """
    return set_layer_config(
        scriptable=kwargs.pop('scriptable', None),
        exportable=kwargs.pop('exportable', None),
        no_jit=kwargs.pop('no_jit', None))


================================================
FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py
================================================
""" Conv2D w/ SAME padding, CondConv, MixedConv

A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
MobileNetV3 models that maintain weight compatibility with original Tensorflow models.

# From PyTorch internals
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    Iterable inputs are converted to a tuple of their own elements (their
    length is not checked), so that callers can safely concatenate the result
    with other tuples even when a list was passed in.
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    """Return True when 'SAME' padding is input-size independent (stride 1, even total pad)."""
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    """Return the symmetric per-side padding for the given kernel, stride and dilation."""
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _calc_same_pad(i: int, k: int, s: int, d: int):
    """Return the total 'SAME' padding along one axis.

    Args:
        i: input size, k: kernel size, s: stride, d: dilation.
    """
    # -(i // -s) is ceil(i / s) using integer arithmetic only.
    return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _same_pad_arg(input_size, kernel_size, stride, dilation):
    """Return ``[left, right, top, bottom]`` 'SAME' padding for a (h, w) input and kernel."""
    ih, iw = input_size
    kh, kw = kernel_size
    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
    # Any odd leftover pixel goes to the right / bottom side.
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]


def _split_channels(num_chan, num_groups):
    """Split ``num_chan`` into ``num_groups`` near-equal parts; the first part takes the remainder."""
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    """2D convolution with padding computed from the input size ('SAME' style).

    ``padding`` is accepted for signature compatibility but ignored: the input
    is padded explicitly and the convolution itself runs with zero padding.
    """
    ih, iw = x.size()[-2:]
    kh, kw = weight.size()[-2:]
    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The ``padding`` constructor argument is ignored; the base layer is always
    created with zero padding and the padding is computed per call.
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The pad layer is built from the spatial size of the input and cached in
    ``self.pad``; it is rebuilt whenever a later input has a different spatial
    size, so the cached padding can never be applied to a mismatched input.

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSameExport, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        input_size = x.size()[-2:]
        # Rebuild the cached pad layer on first use or when the spatial size changed;
        # the padding amount depends on the input size when stride > 1.
        if self.pad is None or tuple(self.pad_input_size) != tuple(input_size):
            pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size

        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding spec into ``(padding, dynamic)``.

    Args:
        padding: a string ('same', 'valid', or anything else for symmetric
            padding) or a non-string value that is returned unchanged.
        kernel_size: kernel size forwarded to the padding helpers.
        **kwargs: ``stride`` / ``dilation`` (other keys are ignored by the helpers).

    Returns:
        Tuple of the resolved padding and a bool that is True only when 'same'
        padding cannot be expressed statically and must be computed per input.
    """
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if _is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = _get_padding(kernel_size, **kwargs)
            else:
                # dynamic padding
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = _get_padding(kernel_size, **kwargs)
    return padding, dynamic
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    One convolution is created per kernel size; input and output channels are
    divided between them with ``_split_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super().__init__()

        # Only a list is treated as "several kernels"; anything else is one kernel.
        kernels = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        group_count = len(kernels)
        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernels, in_splits, out_splits)):
            branch = create_conv2d_pad(
                in_ch, out_ch, k, stride=stride,
                padding=padding, dilation=dilation,
                groups=out_ch if depthwise else 1, **kwargs)
            self.add_module(str(idx), branch)
        self.splits = in_splits

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        outputs = [conv(chunk) for conv, chunk in zip(self.values(), chunks)]
        return torch.cat(outputs, 1)


def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap ``initializer`` so that it is applied to every expert of a CondConv weight.

    The returned function expects a 2D tensor of shape
    ``[num_experts, prod(expert_shape)]`` and raises ValueError otherwise.
    """
    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        shape_ok = (
            len(weight.shape) == 2
            and weight.shape[0] == num_experts
            and weight.shape[1] == num_params)
        if not shape_ok:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            # Each row is viewed in the per-expert shape and handed to the initializer.
            initializer(weight[expert_idx].view(expert_shape))
    return condconv_initializer


class CondConv2d(nn.Module):
    """ Conditional Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983
    """
    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        resolved_padding, needs_dynamic_pad = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = needs_dynamic_pad  # if in forward to work with torchscript
        self.padding = _pair(resolved_padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # Every expert kernel is stored flattened as one row of ``self.weight``.
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        params_per_expert = math.prod(self.weight_shape)
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, params_per_expert))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every expert kernel (and bias, when present) independently."""
        weight_init = partial(nn.init.kaiming_uniform_, a=math.sqrt(5))
        get_condconv_initializer(weight_init, self.num_experts, self.weight_shape)(self.weight)
        if self.bias is None:
            return
        fan_in = np.prod(self.weight_shape[1:])
        bound = 1 / math.sqrt(fan_in)
        bias_init = partial(nn.init.uniform_, a=-bound, b=bound)
        get_condconv_initializer(bias_init, self.num_experts, self.bias_shape)(self.bias)

    def forward(self, x, routing_weights):
        batch, chans, height, width = x.shape

        # Mix the expert kernels per sample, then stack all samples' kernels on the output-channel axis.
        mixed_weight = torch.matmul(routing_weights, self.weight)
        stacked_shape = (batch * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        mixed_weight = mixed_weight.view(stacked_shape)

        mixed_bias = None
        if self.bias is not None:
            mixed_bias = torch.matmul(routing_weights, self.bias)
            mixed_bias = mixed_bias.view(batch * self.out_channels)

        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        x = x.view(1, batch * chans, height, width)
        if self.dynamic_padding:
            conv_out = conv2d_same(
                x, mixed_weight, mixed_bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * batch)
        else:
            conv_out = F.conv2d(
                x, mixed_weight, mixed_bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * batch)

        # Undo the batch/channel folding.
        out = conv_out.permute([1, 0, 2, 3]).view(
            batch, self.out_channels, conv_out.shape[-2], conv_out.shape[-1])
        return out
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) else: depthwise = kwargs.pop('depthwise', False) groups = out_chs if depthwise else 1 if 'num_experts' in kwargs and kwargs['num_experts'] > 0: m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py ================================================ """ EfficientNet / MobileNetV3 Blocks and Builder Copyright 2020 Ross Wightman """ import re from copy import deepcopy from .conv2d_layers import * from geffnet.activations import * __all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' ] # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per # papers and TF reference implementations. 
def make_divisible(v: int, divisor: int = 8, min_value: int = None):
    """Round ``v`` to the nearest multiple of ``divisor``, never below ``min_value``.

    ``min_value`` falls back to ``divisor`` when it is None (or otherwise falsy).
    """
    lower_bound = min_value or divisor
    nearest = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest)
    # ensure round down does not go down by more than 10%.
    if result < 0.9 * v:
        result += divisor
    return result


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier."""
    if multiplier:
        scaled = channels * multiplier
        return make_divisible(scaled, divisor, channel_min)
    # A falsy multiplier disables scaling and rounding altogether.
    return channels


def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
    """Apply drop connect.

    Outside of training the input is returned untouched. During training one
    keep/drop decision is drawn per element of the first dimension and the
    kept entries are scaled by ``1 / keep_prob``.
    """
    if not training:
        return inputs

    keep_prob = 1 - drop_connect_rate
    noise = torch.rand(
        (inputs.shape[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
    keep_mask = (keep_prob + noise).floor_()
    return inputs.div(keep_prob) * keep_mask
class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE

    Layout: point-wise expansion -> depth-wise conv -> optional
    squeeze-and-excitation -> point-wise linear projection. The input is
    added back when stride is 1, input and output channels match and
    ``noskip`` is False.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_connect_rate=0.):
        super().__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs: int = make_divisible(in_chs * exp_ratio)
        self.has_residual = not noskip and (in_chs == out_chs and stride == 1)
        self.drop_connect_rate = drop_connect_rate

        # NOTE: submodules are created in forward order; keep it that way so
        # registration order stays stable.

        # Point-wise expansion
        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        wants_se = se_ratio is not None and se_ratio > 0.
        if wants_se:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()  # for jit.script compat

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def forward(self, x):
        shortcut = x

        # Point-wise expansion
        x = self.act1(self.bn1(self.conv_pw(x)))

        # Depth-wise convolution
        x = self.act2(self.bn2(self.conv_dw(x)))

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection (no activation afterwards)
        x = self.bn3(self.conv_pwl(x))

        if not self.has_residual:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        x += shortcut
        return x
class EdgeResidual(nn.Module):
    """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride

    Args:
        in_chs: input channels.
        out_chs: output channels.
        exp_kernel_size: kernel size of the expansion convolution.
        exp_ratio: multiplier applied to ``in_chs`` (or ``fake_in_chs`` when
            that is > 0) to obtain the mid channel count.
        fake_in_chs: optional override for the channel count the expansion
            ratio is applied to.
        stride: stride of the point-wise linear projection.
        pad_type: padding spec forwarded to ``select_conv2d``.
        act_layer: activation class, instantiated with ``inplace=True``.
        noskip: disable the residual connection even when shapes allow it.
        pw_kernel_size: kernel size of the point-wise linear projection.
        se_ratio: squeeze-and-excitation ratio; None or <= 0 disables SE.
        se_kwargs: extra SE arguments, resolved through ``resolve_se_args``.
        norm_layer: normalisation class used for both norm layers.
        norm_kwargs: keyword arguments for ``norm_layer``.
        drop_connect_rate: drop-connect rate applied before the residual add.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if se_ratio is not None and se_ratio > 0.:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
        # Honour the configured norm layer here as well (bn1 already does);
        # with the default norm_layer this is still nn.BatchNorm2d.
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x
def _parse_ksize(ss):
    """Parse a kernel-size token: ``'3'`` -> ``3``, ``'3.5.7'`` -> ``[3, 5, 7]``."""
    return int(ss) if ss.isdigit() else [int(k) for k in ss.split('.')]


def _decode_block_str(block_str):
    """ Decode block definition string

    Gets a list of block arg (dicts) through a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
      er = EdgeResidual, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw'); other values are ignored
    a / p - expansion / point-wise kernel size (default 1),
    fc - fake input channels (EdgeResidual only),
    cc - number of CondConv experts (InvertedResidual only)

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple of (block args dict, number of repeats)
    Raises:
        KeyError: if an option required by the block type (or ``r``) is missing
        AssertionError: if the block type is not recognised
    """
    assert isinstance(block_str, str)
    # the block type comes first, everything after it is an option token
    block_type, *ops = block_str.split('_')

    act_names = {'re': 'relu', 'r6': 'relu6', 'hs': 'hard_swish', 'sw': 'swish'}
    options = {}
    noskip = False
    for op in ops:
        if op == 'noskip':
            noskip = True
            continue
        if op.startswith('n'):
            # activation fn; unrecognised codes leave the options untouched
            act_name = act_names.get(op[1:])
            if act_name is not None:
                options[op[0]] = get_act_layer(act_name)
            continue
        # all numeric options: split at the first digit into (key, value)
        parts = re.split(r'(\d.*)', op)
        if len(parts) >= 2:
            options[parts[0]] = parts[1]

    def se_ratio():
        return float(options['se']) if 'se' in options else None

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options.get('n')
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = {
            'block_type': block_type,
            'dw_kernel_size': _parse_ksize(options['k']),
            'exp_kernel_size': exp_kernel_size,
            'pw_kernel_size': pw_kernel_size,
            'out_chs': int(options['c']),
            'exp_ratio': float(options['e']),
            'se_ratio': se_ratio(),
            'stride': int(options['s']),
            'act_layer': act_layer,
            'noskip': noskip,
        }
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type in ('ds', 'dsa'):
        with_pw_act = block_type == 'dsa'
        block_args = {
            'block_type': block_type,
            'dw_kernel_size': _parse_ksize(options['k']),
            'pw_kernel_size': pw_kernel_size,
            'out_chs': int(options['c']),
            'se_ratio': se_ratio(),
            'stride': int(options['s']),
            'act_layer': act_layer,
            'pw_act': with_pw_act,
            'noskip': with_pw_act or noskip,
        }
    elif block_type == 'er':
        block_args = {
            'block_type': block_type,
            'exp_kernel_size': _parse_ksize(options['k']),
            'pw_kernel_size': pw_kernel_size,
            'out_chs': int(options['c']),
            'exp_ratio': float(options['e']),
            'fake_in_chs': fake_in_chs,
            'se_ratio': se_ratio(),
            'stride': int(options['s']),
            'act_layer': act_layer,
            'noskip': noskip,
        }
    elif block_type == 'cn':
        block_args = {
            'block_type': block_type,
            'kernel_size': int(options['k']),
            'out_chs': int(options['c']),
            'stride': int(options['s']),
            'act_layer': act_layer,
        }
    else:
        raise AssertionError('Unknown block type (%s)' % block_type)

    return block_args, num_repeat


def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling
    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.

    Returns a flat list with each block-arg dict deep-copied once per scaled repeat.
    """
    # There may be several block defs per stage, so the stage total is what gets scaled.
    total = sum(repeats)
    if depth_trunc == 'round':
        # Rounding lets stages with few repeats stay proportionally smaller for longer;
        # useful when single-repeat stages should remain single as long as possible.
        total_scaled = max(1, round(total * depth_multiplier))
    else:
        # EfficientNet default: 'ceil', so any multiplier > 1.0 deepens every stage.
        total_scaled = int(math.ceil(total * depth_multiplier))

    # Hand out the scaled total proportionally, walking the definitions back to front
    # so that the first block is the least likely one to be repeated more often.
    scaled_reversed = []
    for r in reversed(repeats):
        rs = max(1, round(r / total * total_scaled))
        scaled_reversed.append(rs)
        total -= r
        total_scaled -= rs
    repeats_scaled = scaled_reversed[::-1]

    # Materialise one independent copy of the block args per repeat.
    return [
        deepcopy(ba)
        for ba, rep in zip(stack_args, repeats_scaled)
        for _ in range(rep)
    ]


def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
    """Decode an architecture definition into per-stage lists of block-arg dicts.

    Args:
        arch_def: list of stages, each a list of block definition strings.
        depth_multiplier: per-stage repeat multiplier.
        depth_trunc: 'round' or anything else for ceil (see ``_scale_stage_depth``).
        experts_multiplier: multiplies ``num_experts`` of CondConv blocks when > 1.
        fix_first_last: when True the first and last stage are not depth-scaled.
    """
    arch_args = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args, repeats = [], []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            stack_args.append(ba)
            repeats.append(rep)
        keep_depth = fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1)
        stage_multiplier = 1.0 if keep_depth else depth_multiplier
        arch_args.append(_scale_stage_depth(stack_args, repeats, stage_multiplier, depth_trunc))
    return arch_args
def initialize_weight_goog(m, n='', fix_group_fanout=True):
    """ Initialise one module in place, following the TensorFlow reference init.

    Reference:
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py

    Args:
        m: the module to initialise. Only CondConv2d, nn.Conv2d, nn.BatchNorm2d and
            nn.Linear are handled; any other module is left untouched.
        n: the module's name. A Linear layer whose name contains 'routing_fn'
            includes fan-in in its uniform init range; other Linear layers use
            fan-out only.
        fix_group_fanout: when True, the convolution fan-out is integer-divided by
            the layer's group count.
    """
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both parameters are None; there is nothing to initialise.
        if m.weight is not None:
            m.weight.data.fill_(1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)
        fan_in = m.weight.size(1) if 'routing_fn' in n else 0
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        # A Linear built with bias=False has bias None, same as the conv branches above.
        if m.bias is not None:
            m.bias.data.zero_()
EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 * EfficientNet-Lite * MixNet (Small, Medium, and Large) - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 * MNasNet B1, A1 (SE), Small - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 * FBNet-C - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 * Single-Path NAS Pixel1 - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 * And likely more... Hacked together by / Copyright 2020 Ross Wightman """ import torch.nn as nn import torch.nn.functional as F from .config import layer_config_kwargs, is_scriptable from .conv2d_layers import select_conv2d from .helpers import load_pretrained from .efficientnet_builder import * __all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 
'tf_efficientnet_b3', 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', 'tf_efficientnet_lite4', 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] model_urls = { 'mnasnet_050': None, 'mnasnet_075': None, 'mnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', 'mnasnet_140': None, 'mnasnet_small': None, 'semnasnet_050': None, 'semnasnet_075': None, 'semnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', 'semnasnet_140': None, 'mobilenetv2_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', 'mobilenetv2_110d': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', 'mobilenetv2_120d': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', 'mobilenetv2_140': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', 'fbnetc_100': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', 'spnasnet_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', 'efficientnet_b0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', 'efficientnet_b1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', 'efficientnet_b2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', 'efficientnet_b3': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', 'efficientnet_b4': None, 'efficientnet_b5': None, 'efficientnet_b6': None, 'efficientnet_b7': None, 'efficientnet_b8': None, 'efficientnet_l2': None, 'efficientnet_es': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', 'efficientnet_em': None, 'efficientnet_el': None, 'efficientnet_cc_b0_4e': None, 'efficientnet_cc_b0_8e': None, 'efficientnet_cc_b1_8e': None, 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', 'efficientnet_lite1': None, 'efficientnet_lite2': None, 'efficientnet_lite3': None, 'efficientnet_lite4': None, 'tf_efficientnet_b0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', 'tf_efficientnet_b1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', 'tf_efficientnet_b2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', 'tf_efficientnet_b3': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', 'tf_efficientnet_b4': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', 'tf_efficientnet_b5': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', 'tf_efficientnet_b6': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', 'tf_efficientnet_b7': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', 'tf_efficientnet_b8': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', 'tf_efficientnet_b0_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', 'tf_efficientnet_b1_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', 'tf_efficientnet_b2_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', 'tf_efficientnet_b3_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', 'tf_efficientnet_b4_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', 'tf_efficientnet_b5_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', 'tf_efficientnet_b6_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', 'tf_efficientnet_b7_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', 
'tf_efficientnet_b8_ap': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', 'tf_efficientnet_b0_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', 'tf_efficientnet_b1_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', 'tf_efficientnet_b2_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', 'tf_efficientnet_b3_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', 'tf_efficientnet_b4_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', 'tf_efficientnet_b5_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', 'tf_efficientnet_b6_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', 'tf_efficientnet_b7_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', 'tf_efficientnet_l2_ns_475': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', 'tf_efficientnet_l2_ns': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', 'tf_efficientnet_es': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', 'tf_efficientnet_em': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', 'tf_efficientnet_el': 
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', 'tf_efficientnet_cc_b0_4e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', 'tf_efficientnet_cc_b0_8e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', 'tf_efficientnet_cc_b1_8e': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', 'tf_efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', 'tf_efficientnet_lite1': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', 'tf_efficientnet_lite2': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', 'tf_efficientnet_lite3': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', 'tf_efficientnet_lite4': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', 'tf_mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', 'tf_mixnet_m': 
class GenEfficientNet(nn.Module):
    """ Generic EfficientNets

    An implementation of mobile optimized networks that covers:
      * EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
      * MixNet (Small, Medium, and Large, XL)
      * MNASNet A1, B1, and small
      * FBNet C
      * Single-Path NAS Pixel1

    Layout: conv stem -> norm -> act -> builder-produced blocks -> 1x1 conv head
    -> norm -> act -> global average pool -> (dropout) -> linear classifier.
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 weight_init='goog'):
        """
        Args:
            block_args: decoded per-stage block arguments handed to EfficientNetBuilder.
            num_classes: output size of the final linear classifier.
            in_chans: number of input image channels.
            num_features: output channels of the 1x1 head convolution.
            stem_size: stem output channels; scaled with round_channels unless `fix_stem`.
            fix_stem: keep `stem_size` exactly as given.
            channel_multiplier, channel_divisor, channel_min: channel rounding settings,
                used for the stem and forwarded to the builder.
            pad_type: padding spec forwarded to select_conv2d and the builder.
            act_layer: activation class, instantiated with `inplace=True`.
            drop_rate: dropout probability applied before the classifier.
            drop_connect_rate: forwarded to the builder.
            se_kwargs: forwarded to the builder unchanged.
            norm_layer: normalisation layer class.
            norm_kwargs: extra keyword arguments for `norm_layer`; None means none.
            weight_init: 'goog' selects initialize_weight_goog, anything else
                selects initialize_weight_default.
        """
        super().__init__()
        # The default None cannot be expanded with ** below; treat it as "no extra kwargs".
        if norm_kwargs is None:
            norm_kwargs = {}
        self.drop_rate = drop_rate

        if not fix_stem:
            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, channel_divisor, channel_min,
            pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        # The builder exposes the channel count produced by its last block.
        in_chs = builder.in_chs

        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
        self.bn2 = norm_layer(num_features, **norm_kwargs)
        self.act2 = act_layer(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)

        # Initialise every registered module (named_modules also yields this module itself).
        for n, m in self.named_modules():
            if weight_init == 'goog':
                initialize_weight_goog(m, n)
            else:
                initialize_weight_default(m, n)

    def features(self, x):
        """ Run stem, blocks and head; return the feature map before pooling. """
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.conv_head(x)
        x = self.bn2(x)
        x = self.act2(x)
        return x

    def as_sequential(self):
        """ Return the network as a flat nn.Sequential reusing this model's modules.

        Unlike `forward`, the sequential form always contains an nn.Dropout layer
        (constructed with p=drop_rate), even when drop_rate is 0.
        """
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.conv_head, self.bn2, self.act2,
            self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def forward(self, x):
        """ Features -> global pool -> flatten -> optional dropout -> classifier. """
        x = self.features(x)
        x = self.global_pool(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)
""" arch_def = [ # stage 0, 112x112 in ['ds_r1_k3_s1_c16_noskip'], # stage 1, 112x112 in ['ir_r3_k3_s2_e3_c24'], # stage 2, 56x56 in ['ir_r3_k5_s2_e3_c40'], # stage 3, 28x28 in ['ir_r3_k5_s2_e6_c80'], # stage 4, 14x14in ['ir_r2_k3_s1_e6_c96'], # stage 5, 14x14in ['ir_r4_k5_s2_e6_c192'], # stage 6, 7x7 in ['ir_r1_k3_s1_e6_c320_noskip'] ] with layer_config_kwargs(kwargs): model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a mnasnet-b1 model. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet Paper: https://arxiv.org/pdf/1807.11626.pdf. Args: channel_multiplier: multiplier to number of channels per layer. """ arch_def = [ ['ds_r1_k3_s1_c8'], ['ir_r1_k3_s2_e3_c16'], ['ir_r2_k3_s2_e6_c16'], ['ir_r4_k5_s2_e6_c32_se0.25'], ['ir_r3_k3_s1_e6_c32_se0.25'], ['ir_r3_k5_s2_e6_c88_se0.25'], ['ir_r1_k3_s1_e6_c144'] ] with layer_config_kwargs(kwargs): model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=8, channel_multiplier=channel_multiplier, act_layer=resolve_act_layer(kwargs, 'relu'), norm_kwargs=resolve_bn_args(kwargs), **kwargs ) model = _create_model(model_kwargs, variant, pretrained) return model def _gen_mobilenet_v2( variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): """ Generate MobileNet-V2 network Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py Paper: https://arxiv.org/abs/1801.04381 """ arch_def = [ ['ds_r1_k3_s1_c16'], ['ir_r2_k3_s2_e6_c24'], ['ir_r3_k3_s2_e6_c32'], ['ir_r4_k3_s2_e6_c64'], ['ir_r3_k3_s1_e6_c96'], ['ir_r3_k3_s2_e6_c160'], ['ir_r1_k3_s1_e6_c320'], ] with layer_config_kwargs(kwargs): 
def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build the Single-Path NAS network (searched for the Pixel1 phone).

    Paper: https://arxiv.org/abs/1904.02877

    Args:
        variant: model name, passed to `_create_model`.
        channel_multiplier: multiplier to number of channels per layer.
        pretrained: passed to `_create_model`.
        **kwargs: remaining options, forwarded into the model kwargs.
    """
    stage_defs = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],
        # stage 2, 56x56 in
        ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
        # stage 3, 28x28 in
        ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
        # stage 4, 14x14in
        ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
        # stage 5, 14x14in
        ['ir_r4_k5_s2_e6_c192'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320_noskip']
    ]
    with layer_config_kwargs(kwargs):
        # The helpers run before **kwargs is expanded, in this order.
        # NOTE(review): presumably they consume their own keys from kwargs - confirm.
        decoded_blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_kwargs = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=decoded_blocks,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_kwargs,
            **kwargs
        )
        return _create_model(model_kwargs, variant, pretrained)
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Build an EfficientNet-Edge style model from 'er' and 'ir' block stages.

    Args:
        variant: model name, passed to `_create_model`.
        channel_multiplier: multiplier to number of channels per layer.
        depth_multiplier: multiplier to number of repeats per stage.
        pretrained: passed to `_create_model`.
        **kwargs: remaining options, forwarded into the model kwargs.
    """
    stage_defs = [
        # NOTE `fc` is present to override a mismatch between stem channels and in chs not
        # present in other models
        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
        ['er_r2_k3_s2_e8_c32'],
        ['er_r4_k3_s2_e8_c48'],
        ['ir_r5_k5_s2_e8_c96'],
        ['ir_r4_k5_s1_e8_c144'],
        ['ir_r2_k5_s2_e8_c192'],
    ]
    with layer_config_kwargs(kwargs):
        # Same evaluation order as the keyword arguments of the dict() call below.
        decoded_blocks = decode_arch_def(stage_defs, depth_multiplier)
        head_features = round_channels(1280, channel_multiplier, 8, None)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_kwargs = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=decoded_blocks,
            num_features=head_features,
            stem_size=32,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_kwargs,
            **kwargs,
        )
        return _create_model(model_kwargs, variant, pretrained)
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Build a MixNet Small model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
    Paper: https://arxiv.org/abs/1907.09595

    Args:
        variant: model name, passed to `_create_model`.
        channel_multiplier: multiplier to number of channels per layer.
        pretrained: passed to `_create_model`.
        **kwargs: remaining options, forwarded into the model kwargs.
    """
    stage_defs = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'],  # relu
        # stage 2, 56x56 in
        ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
        # stage 3, 28x28 in
        ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
        # stage 4, 14x14in
        ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
        # stage 5, 14x14in
        ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
        # 7x7
    ]
    with layer_config_kwargs(kwargs):
        # The helpers run before **kwargs is expanded, in this order.
        decoded_blocks = decode_arch_def(stage_defs)
        activation = resolve_act_layer(kwargs, 'relu')
        bn_kwargs = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=decoded_blocks,
            num_features=1536,
            stem_size=16,
            channel_multiplier=channel_multiplier,
            act_layer=activation,
            norm_kwargs=bn_kwargs,
            **kwargs
        )
        return _create_model(model_kwargs, variant, pretrained)
def mnasnet_050(pretrained=False, **kwargs):
    """ MNASNet B1, channel multiplier of 0.5.

    The 0.5 is passed as the second positional argument of `_gen_mnasnet_b1`,
    which is its `channel_multiplier` parameter (it is not a depth multiplier).
    """
    return _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
""" return mnasnet_100(pretrained, **kwargs) def mnasnet_140(pretrained=False, **kwargs): """ MNASNet B1, depth multiplier of 1.4 """ model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model def semnasnet_050(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) return model def semnasnet_075(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) return model def semnasnet_100(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model def mnasnet_a1(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ return semnasnet_100(pretrained, **kwargs) def semnasnet_140(pretrained=False, **kwargs): """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) return model def mnasnet_small(pretrained=False, **kwargs): """ MNASNet Small, depth multiplier of 1.0. 
""" model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv2_100(pretrained=False, **kwargs): """ MobileNet V2 w/ 1.0 channel multiplier """ model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) return model def mobilenetv2_140(pretrained=False, **kwargs): """ MobileNet V2 w/ 1.4 channel multiplier """ model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) return model def mobilenetv2_110d(pretrained=False, **kwargs): """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" model = _gen_mobilenet_v2( 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) return model def mobilenetv2_120d(pretrained=False, **kwargs): """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ model = _gen_mobilenet_v2( 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) return model def fbnetc_100(pretrained=False, **kwargs): """ FBNet-C """ if pretrained: # pretrained model trained with non-default BN epsilon kwargs['bn_eps'] = BN_EPS_TF_DEFAULT model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) return model def spnasnet_100(pretrained=False, **kwargs): """ Single-Path NAS Pixel1""" model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model def efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 """ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model def efficientnet_b1(pretrained=False, **kwargs): """ EfficientNet-B1 """ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 model = _gen_efficientnet( 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) return model def efficientnet_b2(pretrained=False, **kwargs): """ 
def efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3: channel multiplier 1.2, depth multiplier 1.4.

    `pretrained` and any extra keyword arguments are forwarded to
    `_gen_efficientnet`.
    """
    # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
    scaling = dict(channel_multiplier=1.2, depth_multiplier=1.4)
    return _gen_efficientnet('efficientnet_b3', pretrained=pretrained, **scaling, **kwargs)
def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts

    Unlike the 8-expert variants in this file, no `experts_multiplier` is
    passed, so the generator's default is used.

    Args:
        pretrained: forwarded to `_gen_efficientnet_condconv`.
        **kwargs: extra keyword arguments forwarded unchanged.

    Returns:
        The model produced by `_gen_efficientnet_condconv`.
    """
    # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0,
        pretrained=pretrained, **kwargs)
    return model
def tf_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 AutoAug, Tensorflow compatible variant.

    Forces `bn_eps` to BN_EPS_TF_DEFAULT and `pad_type` to 'same' (replacing
    any caller-supplied values) before delegating to `_gen_efficientnet`.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 AutoAug, Tensorflow compatible variant.

    `bn_eps` and `pad_type` are always overwritten with the Tensorflow
    settings; the remaining keyword arguments go straight to
    `_gen_efficientnet`.
    """
    kwargs['bn_eps'], kwargs['pad_type'] = BN_EPS_TF_DEFAULT, 'same'
    return _gen_efficientnet(
        'tf_efficientnet_b3',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
    """ EfficientNet-B0 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)

    Forces the Tensorflow batch-norm epsilon and 'same' padding, then builds
    the model through `_gen_efficientnet` with both multipliers at 1.0.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0_ap',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
    """ EfficientNet-B5 AdvProp, Tensorflow compatible variant.

    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)

    Channel multiplier 1.6, depth multiplier 2.2; `bn_eps` and `pad_type`
    are overwritten with the Tensorflow settings.
    """
    kwargs['pad_type'] = 'same'
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    return _gen_efficientnet(
        'tf_efficientnet_b5_ap',
        channel_multiplier=1.6,
        depth_multiplier=2.2,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
    """ EfficientNet-B0 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)

    Overwrites `bn_eps` / `pad_type` with the Tensorflow settings and hands
    everything to `_gen_efficientnet`.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_b0_ns',
        channel_multiplier=1.0,
        depth_multiplier=1.0,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
    """ EfficientNet-B4 NoisyStudent, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)

    Channel multiplier 1.4, depth multiplier 1.8, with the Tensorflow
    batch-norm epsilon and 'same' padding forced into `kwargs`.
    """
    kwargs['bn_eps'], kwargs['pad_type'] = BN_EPS_TF_DEFAULT, 'same'
    return _gen_efficientnet(
        'tf_efficientnet_b4_ns',
        channel_multiplier=1.4,
        depth_multiplier=1.8,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
    """ EfficientNet-L2 NoisyStudent @ 475x475, Tensorflow compatible variant.

    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)

    Channel multiplier 4.3, depth multiplier 5.3; `bn_eps` and `pad_type`
    are overwritten with the Tensorflow settings.
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet(
        'tf_efficientnet_l2_ns_475',
        channel_multiplier=4.3,
        depth_multiplier=5.3,
        pretrained=pretrained,
        **kwargs,
    )
def tf_efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-Edge-Large, Tensorflow compatible variant.

    Built through `_gen_efficientnet_edge` with channel multiplier 1.2 and
    depth multiplier 1.4; `bn_eps` and `pad_type` are forced to the
    Tensorflow settings.
    """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_efficientnet_edge(
        'tf_efficientnet_el',
        channel_multiplier=1.2,
        depth_multiplier=1.4,
        pretrained=pretrained,
        **kwargs,
    )
def mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model.

    Uses the medium generator (`_gen_mixnet_m`) with the channel multiplier
    raised to 1.3.
    """
    # NOTE for train set drop_rate=0.25
    return _gen_mixnet_m('mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
def mixnet_xxl(pretrained=False, **kwargs):
    """Creates a MixNet Double Extra Large model.

    Not a paper spec, experimental def by RW w/ depth scaling. Uses the medium
    generator (`_gen_mixnet_m`) with channel multiplier 2.4 and depth
    multiplier 1.3.
    """
    # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
    scaling = dict(channel_multiplier=2.4, depth_multiplier=1.3)
    return _gen_mixnet_m('mixnet_xxl', pretrained=pretrained, **scaling, **kwargs)
def load_checkpoint(model, checkpoint_path):
    """Load weights from a checkpoint file on disk into `model`.

    Two checkpoint layouts are handled:

    * a dict holding the weights under the 'state_dict' key; a leading
      'module.' on any key is dropped before loading (presumably left behind
      by a wrapper such as DataParallel -- confirm against the checkpoints in use);
    * anything else, which is handed to `model.load_state_dict` as is.

    Args:
        model: object exposing `load_state_dict(state_dict)`.
        checkpoint_path: path of the checkpoint file.

    Raises:
        FileNotFoundError: if `checkpoint_path` is empty or is not a file.
    """
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    print("=> Loading checkpoint '{}'".format(checkpoint_path))
    # map_location='cpu' lets a checkpoint saved on another device be read on
    # any machine (same choice as load_pretrained); load_state_dict then copies
    # the values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            # Strip only an exact 'module.' prefix; the old test for 'module'
            # also chopped 7 characters off keys such as 'modules_x.weight'.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(checkpoint)
    print("=> Loaded checkpoint '{}'".format(checkpoint_path))
input_conv_weight, pretrained_in_chans)) conv1_weight = state_dict[input_conv_weight] state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) else: print('=> Discarding pretrained input conv {} since input channel count != {}'.format( input_conv_weight, pretrained_in_chans)) del state_dict[input_conv_weight] strict = False classifier_weight = classifier + '.weight' pretrained_num_classes = state_dict[classifier_weight].shape[0] if num_classes != pretrained_num_classes: print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) del state_dict[classifier_weight] del state_dict[classifier + '.bias'] strict = False if filter_fn is not None: state_dict = filter_fn(state_dict) model.load_state_dict(state_dict, strict=strict) ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py ================================================ """ MobileNet-V3 A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. 
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by / Copyright 2020 Ross Wightman """ import torch.nn as nn import torch.nn.functional as F from .activations import get_act_fn, get_act_layer, HardSwish from .config import layer_config_kwargs from .conv2d_layers import select_conv2d from .helpers import load_pretrained from .efficientnet_builder import * __all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] model_urls = { 'mobilenetv3_rw': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', 'mobilenetv3_large_075': None, 'mobilenetv3_large_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', 'mobilenetv3_large_minimal_100': None, 'mobilenetv3_small_075': None, 'mobilenetv3_small_100': None, 'mobilenetv3_small_minimal_100': None, 'tf_mobilenetv3_large_075': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', 'tf_mobilenetv3_large_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', 'tf_mobilenetv3_large_minimal_100': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', 'tf_mobilenetv3_small_075': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', 'tf_mobilenetv3_small_100': 
class MobileNetV3(nn.Module):
    """ MobileNet-V3

    This model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
    head convolution without a final batch-norm layer before the classifier.

    Paper: https://arxiv.org/abs/1905.02244
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
        """Build stem, blocks, efficient head and classifier.

        Args:
            block_args: block definitions handed to `EfficientNetBuilder`.
            num_classes: output size of the final linear classifier.
            in_chans: input channels of the stem convolution.
            stem_size: stem width before `round_channels` scaling.
            num_features: output channels of the head convolution.
            head_bias: whether the head convolution has a bias.
            channel_multiplier: width multiplier for the stem and the builder.
            pad_type: padding mode passed to `select_conv2d` and the builder.
            act_layer: activation layer class, instantiated with `inplace=True`.
            drop_rate: dropout probability applied before the classifier.
            drop_connect_rate: forwarded to the builder.
            se_kwargs: forwarded to the builder.
            norm_layer: norm layer class forwarded to the builder.
            norm_kwargs: keyword arguments for the norm layers; None means none.
            weight_init: 'goog' selects `initialize_weight_goog`, anything else
                `initialize_weight_default`.
        """
        super(MobileNetV3, self).__init__()
        self.drop_rate = drop_rate

        # The signature defaults norm_kwargs to None, which cannot be unpacked
        # with ** below; treat None as "no extra arguments".
        norm_kwargs = norm_kwargs or {}

        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        # NOTE(review): the stem norm is always nn.BatchNorm2d; `norm_layer` only
        # reaches the builder. Left as is to keep existing behavior.
        self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        in_chs = builder.in_chs

        # Efficient head: pool first, then the 1x1 head conv (no norm layer after it).
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
        self.act2 = act_layer(inplace=True)
        self.classifier = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if weight_init == 'goog':
                initialize_weight_goog(m)
            else:
                initialize_weight_default(m)

    def as_sequential(self):
        """Return the model's layers chained in a single `nn.Sequential`.

        A `nn.Flatten` and a `nn.Dropout(self.drop_rate)` are inserted between
        the head activation and the classifier.
        """
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.global_pool, self.conv_head, self.act2,
            nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def features(self, x):
        """Run stem, blocks, global pooling and the head conv + activation."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        """Compute classifier outputs; dropout is applied only when drop_rate > 0."""
        x = self.features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MobileNet-V3 large/small/minimal model.

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
    Paper: https://arxiv.org/abs/1905.02244

    The layout is picked from the variant name: 'small' in the name selects the
    small tables (1024 head features) instead of the large ones (1280), and
    'minimal' selects the tables that use only `k3` entries without `se`/`nre`
    suffixes, with 'relu' as the default activation instead of 'hard_swish'.

    Args:
        variant: model name; also used to look up pretrained weights.
        channel_multiplier: multiplier to number of channels per layer.
        pretrained: forwarded to `_create_model`.
        **kwargs: extra model arguments, forwarded to the model constructor.
    """
    small = 'small' in variant
    minimal = 'minimal' in variant
    num_features = 1024 if small else 1280
    default_act = 'relu' if minimal else 'hard_swish'

    if small and minimal:
        arch_def = [
            ['ds_r1_k3_s2_e1_c16'],  # stage 0, 112x112 in
            ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],  # stage 1, 56x56 in
            ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],  # stage 2, 28x28 in
            ['ir_r2_k3_s1_e3_c48'],  # stage 3, 14x14 in
            ['ir_r3_k3_s2_e6_c96'],  # stage 4, 14x14 in
            ['cn_r1_k1_s1_c576'],  # stage 6, 7x7 in
        ]
    elif small:
        arch_def = [
            ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # stage 0, 112x112 in (relu)
            ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # stage 1, 56x56 in (relu)
            ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # stage 2, 28x28 in (hard-swish)
            ['ir_r2_k5_s1_e3_c48_se0.25'],  # stage 3, 14x14 in (hard-swish)
            ['ir_r3_k5_s2_e6_c96_se0.25'],  # stage 4, 14x14 in (hard-swish)
            ['cn_r1_k1_s1_c576'],  # stage 6, 7x7 in (hard-swish)
        ]
    elif minimal:
        arch_def = [
            ['ds_r1_k3_s1_e1_c16'],  # stage 0, 112x112 in
            ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],  # stage 1, 112x112 in
            ['ir_r3_k3_s2_e3_c40'],  # stage 2, 56x56 in
            ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # stage 3, 28x28 in
            ['ir_r2_k3_s1_e6_c112'],  # stage 4, 14x14 in
            ['ir_r3_k3_s2_e6_c160'],  # stage 5, 14x14 in
            ['cn_r1_k1_s1_c960'],  # stage 6, 7x7 in
        ]
    else:
        arch_def = [
            ['ds_r1_k3_s1_e1_c16_nre'],  # stage 0, 112x112 in (relu)
            ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # stage 1, 112x112 in (relu)
            ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # stage 2, 56x56 in (relu)
            ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # stage 3, 28x28 in (hard-swish)
            ['ir_r2_k3_s1_e6_c112_se0.25'],  # stage 4, 14x14 in (hard-swish)
            ['ir_r3_k5_s2_e6_c160_se0.25'],  # stage 5, 14x14 in (hard-swish)
            ['cn_r1_k1_s1_c960'],  # stage 6, 7x7 in (hard-swish)
        ]

    with layer_config_kwargs(kwargs):
        # The resolve_* helpers receive `kwargs` before it is expanded below
        # (presumably they consume their own keys) -- keep this call order.
        block_args = decode_arch_def(arch_def)
        act_layer = resolve_act_layer(kwargs, default_act)
        se_kwargs = dict(
            act_layer=get_act_layer('relu'),
            gate_fn=get_act_fn('hard_sigmoid'),
            reduce_mid=True,
            divisor=8,
        )
        norm_kwargs = resolve_bn_args(kwargs)
        model_kwargs = dict(
            block_args=block_args,
            num_features=num_features,
            stem_size=16,
            channel_multiplier=channel_multiplier,
            act_layer=act_layer,
            se_kwargs=se_kwargs,
            norm_kwargs=norm_kwargs,
            **kwargs,
        )
        model = _create_model(model_kwargs, variant, pretrained)
    return model
def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 Large 0.75. Tensorflow compat variant. """
    # TF-compat variants always force the TF BN epsilon and 'same' padding.
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)


def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large 1.0. Tensorflow compat variant. """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)


def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)


def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
    """ MobileNet V3 Small 0.75. Tensorflow compat variant. """
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)


def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
    """ MobileNet V3 Small 1.0. Tensorflow compat variant."""
    kwargs.update(bn_eps=BN_EPS_TF_DEFAULT, pad_type='same')
    return _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
""" kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) return model ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py ================================================ from .config import set_layer_config from .helpers import load_checkpoint from .gen_efficientnet import * from .mobilenetv3 import * def create_model( model_name='mnasnet_100', pretrained=None, num_classes=1000, in_chans=3, checkpoint_path='', **kwargs): model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) if model_name in globals(): create_fn = globals()[model_name] model = create_fn(**model_kwargs) else: raise RuntimeError('Unknown model (%s)' % model_name) if checkpoint_path and not pretrained: load_checkpoint(model, checkpoint_path) return model ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py ================================================ __version__ = '1.0.2' ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/hubconf.py ================================================ dependencies = ['torch', 'math'] from geffnet import efficientnet_b0 from geffnet import efficientnet_b1 from geffnet import efficientnet_b2 from geffnet import efficientnet_b3 from geffnet import efficientnet_es from geffnet import efficientnet_lite0 from geffnet import mixnet_s from geffnet import mixnet_m from geffnet import mixnet_l from geffnet import mixnet_xl from geffnet import mobilenetv2_100 from geffnet import mobilenetv2_110d from geffnet import mobilenetv2_120d from geffnet import mobilenetv2_140 from geffnet import mobilenetv3_large_100 from geffnet import mobilenetv3_rw from geffnet import mnasnet_a1 from 
geffnet import mnasnet_b1 from geffnet import fbnetc_100 from geffnet import spnasnet_100 from geffnet import tf_efficientnet_b0 from geffnet import tf_efficientnet_b1 from geffnet import tf_efficientnet_b2 from geffnet import tf_efficientnet_b3 from geffnet import tf_efficientnet_b4 from geffnet import tf_efficientnet_b5 from geffnet import tf_efficientnet_b6 from geffnet import tf_efficientnet_b7 from geffnet import tf_efficientnet_b8 from geffnet import tf_efficientnet_b0_ap from geffnet import tf_efficientnet_b1_ap from geffnet import tf_efficientnet_b2_ap from geffnet import tf_efficientnet_b3_ap from geffnet import tf_efficientnet_b4_ap from geffnet import tf_efficientnet_b5_ap from geffnet import tf_efficientnet_b6_ap from geffnet import tf_efficientnet_b7_ap from geffnet import tf_efficientnet_b8_ap from geffnet import tf_efficientnet_b0_ns from geffnet import tf_efficientnet_b1_ns from geffnet import tf_efficientnet_b2_ns from geffnet import tf_efficientnet_b3_ns from geffnet import tf_efficientnet_b4_ns from geffnet import tf_efficientnet_b5_ns from geffnet import tf_efficientnet_b6_ns from geffnet import tf_efficientnet_b7_ns from geffnet import tf_efficientnet_l2_ns_475 from geffnet import tf_efficientnet_l2_ns from geffnet import tf_efficientnet_es from geffnet import tf_efficientnet_em from geffnet import tf_efficientnet_el from geffnet import tf_efficientnet_cc_b0_4e from geffnet import tf_efficientnet_cc_b0_8e from geffnet import tf_efficientnet_cc_b1_8e from geffnet import tf_efficientnet_lite0 from geffnet import tf_efficientnet_lite1 from geffnet import tf_efficientnet_lite2 from geffnet import tf_efficientnet_lite3 from geffnet import tf_efficientnet_lite4 from geffnet import tf_mixnet_s from geffnet import tf_mixnet_m from geffnet import tf_mixnet_l from geffnet import tf_mobilenetv3_large_075 from geffnet import tf_mobilenetv3_large_100 from geffnet import tf_mobilenetv3_large_minimal_100 from geffnet import tf_mobilenetv3_small_075 from 
geffnet import tf_mobilenetv3_small_100 from geffnet import tf_mobilenetv3_small_minimal_100 ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/requirements.txt ================================================ torch>=1.2.0 torchvision>=0.4.0 ================================================ FILE: modules/control/proc/normalbae/nets/submodules/efficientnet_repo/setup.py ================================================ """ Setup """ from setuptools import setup, find_packages from codecs import open from os import path here = path.abspath(path.dirname(__file__)) __version__ = '0.0.0' # Get the long description from the README file with open(path.join(here, 'README.md'), encoding='utf-8') as f: long_description = f.read() exec(open('geffnet/version.py').read()) setup( name='geffnet', version=__version__, description='(Generic) EfficientNets for PyTorch', long_description=long_description, long_description_content_type='text/markdown', url='https://github.com/rwightman/gen-efficientnet-pytorch', author='Ross Wightman', author_email='hello@rwightman.com', classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable 'Development Status :: 3 - Alpha', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', ], # Note that this is a string of words separated by whitespace, not a list. 
class AverageMeter:
    """Keeps the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k (in percent) for each k in ``topk``.

    ``output`` is scored along dim 1 and ``target`` holds one class index per
    row; the result is a list with one tensor per requested k.
    """
    batch_size = target.size(0)
    # Indices of the max(topk) best classes per sample, transposed to (k, batch).
    ranked = output.topk(max(topk), 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    scale = 100.0 / batch_size
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]


def get_outdir(path, *paths, inc=False):
    """Create (if needed) and return the directory ``path/paths...``.

    When the directory already exists and ``inc`` is true, a fresh sibling named
    ``<dir>-1``, ``<dir>-2``, ... is created and returned instead.
    """
    base = os.path.join(path, *paths)
    if not os.path.exists(base):
        os.makedirs(base)
        return base
    if not inc:
        return base

    suffix = 1
    candidate = base + '-' + str(suffix)
    while os.path.exists(candidate):
        suffix += 1
        candidate = base + '-' + str(suffix)
        assert suffix < 100
    os.makedirs(candidate)
    return candidate
# Command-line interface of the ImageNet validation script.
# The default model must be a factory name that geffnet actually exports; the
# previous default ('spnasnet1_00') is not one, so a run without --model failed
# with "Unknown model". Help strings now state the real defaults.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet_100',
                    help='model architecture (default: spnasnet_100)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int, metavar='N',
                    help='Input image dimension, uses model default if empty')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
                    help='Override default crop pct of 0.875')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
                    help='use tensorflow mnasnet preprocessing')
parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
                    help='do not move the model and data to CUDA')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use native Torch AMP mixed precision.')
class Encoder(nn.Module):
    """Backbone wrapper that returns the output of every stage.

    The backbone is the ``tf_efficientnet_b5_ap`` entry point of the bundled
    ``efficientnet_repo`` (loaded through ``torch.hub`` from the local directory,
    without pretrained weights), with its pooling and classifier replaced by
    identities.
    """

    def __init__(self):
        super().__init__()
        hub_dir = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        backbone = torch.hub.load(hub_dir, 'tf_efficientnet_b5_ap', pretrained=False, source='local')
        # Drop the classification head: keep spatial feature maps only.
        backbone.global_pool = nn.Identity()
        backbone.classifier = nn.Identity()
        self.original_model = backbone

    def forward(self, x):
        """Return ``[x, out_1, out_2, ...]``, one entry per executed stage.

        Top-level children run in registration order; the child named
        ``'blocks'`` is expanded so each of its own children adds one entry.
        """
        feats = [x]
        for name, module in self.original_model._modules.items():
            stages = module._modules.values() if name == 'blocks' else (module,)
            for stage in stages:
                feats.append(stage(feats[-1]))
        return feats
# Upsample + BatchNorm
class UpSampleBN(nn.Module):
    """Decoder stage: upsample ``x`` to the skip tensor's size, concatenate, refine.

    Refinement is two 3x3 conv -> BatchNorm -> LeakyReLU stacks. ``skip_input``
    is the channel count of the concatenated tensor.
    """

    def __init__(self, skip_input, output_features):
        super(UpSampleBN, self).__init__()
        # Layer order defines the ``_net.<index>`` parameter names; do not reorder.
        self._net = nn.Sequential(
            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        )

    def forward(self, x, concat_with):
        """Bilinearly resize ``x`` to ``concat_with``'s H x W, concat on dim 1, run the stack."""
        target_size = [concat_with.size(2), concat_with.size(3)]
        up_x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        return self._net(torch.cat([up_x, concat_with], dim=1))


# Upsample + GroupNorm + Weight Standardization
class UpSampleGN(nn.Module):
    """Same as ``UpSampleBN`` but with weight-standardized convs and GroupNorm(8).

    ``output_features`` must be divisible by 8 (GroupNorm requirement).
    """

    def __init__(self, skip_input, output_features):
        super(UpSampleGN, self).__init__()
        # Layer order defines the ``_net.<index>`` parameter names; do not reorder.
        self._net = nn.Sequential(
            Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
            Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
        )

    def forward(self, x, concat_with):
        """Bilinearly resize ``x`` to ``concat_with``'s H x W, concat on dim 1, run the stack."""
        target_size = [concat_with.size(2), concat_with.size(3)]
        up_x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        return self._net(torch.cat([up_x, concat_with], dim=1))


# Conv2d with weight standardization
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardized per output channel on every call."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Zero-mean each output filter (mean over in-channels and both kernel dims) ...
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # ... then divide by its std; the epsilon avoids division by zero.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


# normalize
def norm_normalize(norm_out):
    """Normalize a 4-channel prediction along dim 1.

    Channels 0-2 are divided by their joint L2 norm (plus 1e-10); channel 3
    (kappa) is mapped through ``elu(kappa) + 1 + 0.01``, so it stays above 0.01.
    """
    min_kappa = 0.01
    norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
    norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
    kappa = F.elu(kappa) + 1.0 + min_kappa
    return torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)


# uncertainty-guided sampling (only used during training)
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
    """Pick ``N = int(sampling_ratio * H * W)`` pixel locations per batch item.

    The first ``int(beta * N)`` samples are the most uncertain pixels (lowest
    value in channel 3 of ``init_normal``); the rest are drawn at random from
    the remaining pixels. Pixels where ``gt_norm_mask`` is below 0.5 are pushed
    to the end of the uncertainty ranking.

    Args:
        init_normal: tensor of shape (B, C, H, W) with C >= 4.
        gt_norm_mask: optional validity mask, resized to (H, W) with nearest
            interpolation; ``None`` disables masking.
        sampling_ratio: fraction of the H*W pixels to sample.
        beta: fraction of the N samples taken by uncertainty.

    Returns:
        ``(point_coords, rows_int, cols_int)``: ``point_coords`` is
        (B, 1, N, 2) holding (x, y) in [-1, 1]; the other two are the (B, N)
        integer row and column indices.
    """
    device = init_normal.device
    B, _, H, W = init_normal.shape
    N = int(sampling_ratio * H * W)

    # Higher value = more uncertain. The multiplication yields a new tensor, so
    # the masking below does not modify ``init_normal``.
    uncertainty_map = -1 * init_normal[:, 3, :, :]  # B, H, W

    if gt_norm_mask is not None:
        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
        uncertainty_map[gt_invalid_mask] = -1e4

    # Flat pixel indices, most uncertain first: (B, H*W)
    _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)

    num_importance = max(int(beta * N), 0)
    num_coverage = N - num_importance

    if num_importance > 0 and num_coverage <= 0:
        # Everything is taken by importance; nothing left to draw at random.
        samples = idx[:, :num_importance]
    else:
        remaining = idx[:, num_importance:]  # B, H*W - num_importance
        # One independent random subset per batch item.
        coverage = torch.cat(
            [remaining[i, torch.randperm(remaining.size(1))[:num_coverage]].view(1, -1)
             for i in range(B)],
            dim=0)  # B, num_coverage
        if num_importance > 0:
            samples = torch.cat((idx[:, :num_importance], coverage), dim=1)  # B, N
        else:
            samples = coverage

    # Flat index -> (row, col) -> coordinates in [-1, 1]. A one-pixel-wide axis
    # would divide by zero and yield NaN, so the denominator is clamped to 1.
    rows_int = samples // W  # 0 for first row, H-1 for last row
    rows_float = (rows_int / float(max(H - 1, 1))) * 2.0 - 1.0
    cols_int = samples % W  # 0 for first column, W-1 for last column
    cols_float = (cols_int / float(max(W - 1, 1))) * 2.0 - 1.0

    point_coords = torch.zeros(B, 1, N, 2, device=device)
    point_coords[:, 0, :, 0] = cols_float  # x coord
    point_coords[:, 0, :, 1] = rows_float  # y coord
    return point_coords, rows_int, cols_int
As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. COPYRIGHT: The Software is owned by Licensor and is protected by United States copyright laws and applicable international treaties and/or conventions. PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. 
BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. 
DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. EXPORT REGULATION: Licensee agrees to comply with any and all applicable U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. 
ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. ************************************************************************ THIRD-PARTY SOFTWARE NOTICES AND INFORMATION This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. 1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) COPYRIGHT All contributions by the University of California: Copyright (c) 2014-2017 The Regents of the University of California (Regents) All rights reserved. All other contributions: Copyright (c) 2014-2017, the respective contributors All rights reserved. Caffe uses a shared copyright model: each contributor holds copyright over their contributions to Caffe. The project versioning records all such contribution and copyright details. If a contributor wants to further mark their specific copyright on a particular contribution, they should indicate their copyright solely in the commit message of the change when it is committed. LICENSE Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. CONTRIBUTION AGREEMENT By contributing to the BVLC/caffe repository through pull-request, comment, or otherwise, the contributor releases their content to the license and copyright terms herein. ************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** ================================================ FILE: modules/control/proc/openpose/__init__.py ================================================ # Openpose # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose # 2nd Edited by https://github.com/Hzzone/pytorch-openpose # 3rd Edited by ControlNet # 4th Edited by ControlNet (added face and correct hands) # 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) # This preprocessor is licensed by CMU for non-commercial use only. import os os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" import warnings from typing import List, NamedTuple, Tuple, Union import cv2 import numpy as np from huggingface_hub import hf_hub_download from PIL import Image from modules import devices from modules.shared import opts from modules.control.util import HWC3, resize_image from . 
class OpenposeDetector:
    """
    Detect human poses (body, and optionally hands and face) in an image.

    Attributes:
        body_estimation: Body-pose estimator. It is called with the image and
            must also provide ``to(device)`` and ``format_body_result``.
        hand_estimation: Optional hand-keypoint estimator, or None. When None,
            hand detection is skipped and no hands are reported.
        face_estimation: Optional face-landmark estimator, or None. When None,
            face detection is skipped and no face is reported.
    """

    def __init__(self, body_estimation, hand_estimation=None, face_estimation=None):
        self.body_estimation = body_estimation
        self.hand_estimation = hand_estimation
        self.face_estimation = face_estimation

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, hand_filename=None, face_filename=None, cache_dir=None, local_files_only=False):
        """
        Build a detector from a local directory or a Hugging Face Hub repository.

        Args:
            pretrained_model_or_path: Local directory containing the weights, or a Hub repo id.
            filename: Body weights file name; a default is chosen when omitted.
            hand_filename: Hand weights file name; a default is chosen when omitted.
            face_filename: Face weights file name; defaults to ``facenet.pth``.
            cache_dir: Forwarded to ``hf_hub_download``.
            local_files_only: Forwarded to ``hf_hub_download``.

        Returns:
            OpenposeDetector: Detector with body, hand and face estimators loaded.
        """
        if pretrained_model_or_path == "lllyasviel/ControlNet":
            # Defaults for this repo id: body/hand weights use the annotator/ckpts
            # prefix and the face weights are taken from a different repo.
            filename = filename or "annotator/ckpts/body_pose_model.pth"
            hand_filename = hand_filename or "annotator/ckpts/hand_pose_model.pth"
            face_pretrained_model_or_path = "lllyasviel/Annotators"
        else:
            filename = filename or "body_pose_model.pth"
            hand_filename = hand_filename or "hand_pose_model.pth"
            face_pretrained_model_or_path = pretrained_model_or_path
        face_filename = face_filename or "facenet.pth"

        if os.path.isdir(pretrained_model_or_path):
            body_model_path = os.path.join(pretrained_model_or_path, filename)
            hand_model_path = os.path.join(pretrained_model_or_path, hand_filename)
            face_model_path = os.path.join(face_pretrained_model_or_path, face_filename)
        else:
            body_model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
            hand_model_path = hf_hub_download(pretrained_model_or_path, hand_filename, cache_dir=cache_dir, local_files_only=local_files_only)
            face_model_path = hf_hub_download(face_pretrained_model_or_path, face_filename, cache_dir=cache_dir, local_files_only=local_files_only)

        body_estimation = Body(body_model_path)
        hand_estimation = Hand(hand_model_path)
        face_estimation = Face(face_model_path)
        return cls(body_estimation, hand_estimation, face_estimation)

    def to(self, device):
        """
        Move every available estimator to ``device``.

        The hand and face estimators are optional in ``__init__``, so they are
        skipped when absent instead of raising ``AttributeError``.

        Returns:
            OpenposeDetector: ``self``, to allow chaining.
        """
        self.body_estimation.to(device)
        for optional_estimator in (self.hand_estimation, self.face_estimation):
            if optional_estimator is not None:
                optional_estimator.to(device)
        return self

    def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
        """
        Detect the left and right hand keypoints of one body.

        Args:
            body: Body result whose keypoints are used to locate hand regions.
            oriImg: Image array indexed as ``[row, column, channel]``.

        Returns:
            ``(left_hand, right_hand)``; each entry is a list of keypoints with
            coordinates divided by the image width/height, or None when that
            hand was not found or no hand estimator is configured.
        """
        left_hand = None
        right_hand = None
        if self.hand_estimation is None:
            return left_hand, right_hand

        H, W, _ = oriImg.shape
        for x, y, w, is_left in util.handDetect(body, oriImg):
            # The estimator works on the square crop; peaks are crop-relative.
            peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32)
            if peaks.ndim == 2 and peaks.shape[1] == 2:
                # Near-zero coordinates are replaced by -1 before normalisation;
                # all others are shifted back into full-image coordinates.
                peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
                peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
                hand_result = [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]
                if is_left:
                    left_hand = hand_result
                else:
                    right_hand = hand_result

        return left_hand, right_hand

    def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]:
        """
        Detect the face landmarks of one body.

        Returns:
            A list of keypoints with coordinates divided by the image
            width/height, or None when no face region is found, no usable
            peaks are produced, or no face estimator is configured.
        """
        if self.face_estimation is None:
            return None

        face = util.faceDetect(body, oriImg)
        if face is None:
            return None

        x, y, w = face
        H, W, _ = oriImg.shape
        heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :])
        peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32)
        if peaks.ndim == 2 and peaks.shape[1] == 2:
            peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W)
            peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H)
            return [Keypoint(x=peak[0], y=peak[1]) for peak in peaks]

        return None

    def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]:
        """
        Detect poses in the given image.

        Args:
            oriImg (numpy.ndarray): The input image for pose detection.
            include_hand (bool, optional): Whether to include hand detection. Defaults to False.
            include_face (bool, optional): Whether to include face detection. Defaults to False.

        Returns:
            List[PoseResult]: A list of PoseResult objects containing the detected poses.
        """
        # Reverse the channel axis (the estimators are fed the flipped order)
        # and copy so the array is contiguous.
        oriImg = oriImg[:, :, ::-1].copy()
        H, W, _C = oriImg.shape
        candidate, subset = self.body_estimation(oriImg)
        bodies = self.body_estimation.format_body_result(candidate, subset)

        results = []
        for body in bodies:
            left_hand, right_hand, face = (None,) * 3
            # Hand/face detection needs the pixel-space body keypoints, so it
            # runs before the body keypoints are normalised below.
            if include_hand:
                left_hand, right_hand = self.detect_hands(body, oriImg)
            if include_face:
                face = self.detect_face(body, oriImg)

            normalized_keypoints = [
                Keypoint(x=keypoint.x / float(W), y=keypoint.y / float(H)) if keypoint is not None else None
                for keypoint in body.keypoints
            ]
            results.append(PoseResult(
                BodyResult(keypoints=normalized_keypoints, total_score=body.total_score, total_parts=body.total_parts),
                left_hand,
                right_hand,
                face,
            ))

        return results

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", **kwargs):
        """
        Run pose detection on an image and render the result as a pose map.

        Args:
            input_image: Image as a numpy array, or anything ``np.array`` can convert.
            detect_resolution: Resolution passed to ``resize_image`` before detection.
            image_resolution: Resolution passed to ``resize_image`` to size the output map.
            include_body: Draw body keypoints.
            include_hand: Detect and draw hands.
            include_face: Detect and draw the face.
            hand_and_face: Deprecated; overrides both ``include_hand`` and ``include_face``.
            output_type: ``"pil"`` for a PIL image; any other value returns a numpy array.
            **kwargs: Only the deprecated ``return_pil`` flag is read.

        Returns:
            The rendered pose map as a PIL image or a numpy array.
        """
        self.to(devices.device)
        if hand_and_face is not None:
            warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning)
            include_hand = hand_and_face
            include_face = hand_and_face

        if "return_pil" in kwargs:
            warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
            output_type = "pil" if kwargs["return_pil"] else "np"
        if isinstance(output_type, bool):
            warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
            output_type = "pil" if output_type else "np"

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)
        H, W, _C = input_image.shape

        poses = self.detect_poses(input_image, include_hand, include_face)
        canvas = draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
        detected_map = HWC3(canvas)

        # The map is drawn at detection resolution, then resized to the size
        # the input would have at ``image_resolution``.
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if opts.control_move_processor:
            self.to('cpu')

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
def __call__(self, oriImg):
    """
    Run body-pose estimation on one image.

    Args:
        oriImg: Image array indexed as ``[row, column, channel]``.
            NOTE(review): channel order expected by the weights is not
            visible here - confirm against the caller.

    Returns:
        tuple: ``(candidate, subset)``.
            candidate: one row per detected keypoint peak: x, y, score, id.
            subset: n*20 array, one row per person; columns 0-17 are indices
            into ``candidate`` (-1 when the part is missing), column 18 is the
            total score and column 19 the number of parts.
    """
    device = next(iter(self.model.parameters())).device
    # A single scale is evaluated; more factors may be listed here and the
    # network outputs are then averaged over all of them.
    scale_search = [0.5]
    boxsize = 368
    stride = 8
    padValue = 128
    thre1 = 0.1  # minimum smoothed heatmap value for a keypoint peak
    thre2 = 0.05  # minimum PAF alignment score for a sampled limb point
    multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
    heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
    paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

    for scale in multiplier:
        imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale)
        imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
        im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
        im = np.ascontiguousarray(im)

        data = torch.from_numpy(im).float()
        data = data.to(device)
        # Inference only. Without this, the ``.numpy()`` calls below raise on
        # outputs that track gradients when the caller has not disabled autograd.
        with torch.no_grad():
            Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
        Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
        Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

        # Extract outputs, upsample by the stride, remove padding, and resize
        # back to the input resolution.
        heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
        heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride)
        heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
        heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1]))

        paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
        paf = util.smart_resize_k(paf, fx=stride, fy=stride)
        paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
        paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1]))

        # Running average over scales. The accumulator must not be added to
        # itself, otherwise earlier scales are doubled on every iteration.
        heatmap_avg += heatmap / len(multiplier)
        paf_avg += paf / len(multiplier)

    all_peaks = []
    peak_counter = 0

    for part in range(18):
        map_ori = heatmap_avg[:, :, part]
        one_heatmap = gaussian_filter(map_ori, sigma=3)

        # Shifted copies of the smoothed map: a peak is a pixel that is not
        # smaller than any of its four direct neighbours and exceeds thre1.
        map_left = np.zeros(one_heatmap.shape)
        map_left[1:, :] = one_heatmap[:-1, :]
        map_right = np.zeros(one_heatmap.shape)
        map_right[:-1, :] = one_heatmap[1:, :]
        map_up = np.zeros(one_heatmap.shape)
        map_up[:, 1:] = one_heatmap[:, :-1]
        map_down = np.zeros(one_heatmap.shape)
        map_down[:, :-1] = one_heatmap[:, 1:]

        peaks_binary = np.logical_and.reduce(
            (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
        rows, cols = np.nonzero(peaks_binary)
        peaks = list(zip(cols, rows))  # note reverse: stored as (x, y)
        peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
        peaks_with_score_and_id = [peak + (peak_counter + offset,) for offset, peak in enumerate(peaks_with_score)]

        all_peaks.append(peaks_with_score_and_id)
        peak_counter += len(peaks)

    # find connection in the specified sequence, center 29 is in the position 15
    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
               [1, 16], [16, 18], [3, 17], [6, 18]]
    # the middle joints heatmap correspondence
    mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22],
              [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52],
              [55, 56], [37, 38], [45, 46]]

    connection_all = []
    special_k = []
    mid_num = 10  # number of points sampled along each candidate limb

    for k in range(len(mapIdx)):
        score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
        candA = all_peaks[limbSeq[k][0] - 1]
        candB = all_peaks[limbSeq[k][1] - 1]
        nA = len(candA)
        nB = len(candB)
        if nA != 0 and nB != 0:
            connection_candidate = []
            for i in range(nA):
                for j in range(nB):
                    vec = np.subtract(candB[j][:2], candA[i][:2])
                    norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                    norm = max(0.001, norm)  # avoid division by zero for coincident peaks
                    vec = np.divide(vec, norm)

                    startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num),
                                        np.linspace(candA[i][1], candB[j][1], num=mid_num)))

                    vec_x = np.array([score_mid[int(round(startend[x][1])), int(round(startend[x][0])), 0]
                                      for x in range(len(startend))])
                    vec_y = np.array([score_mid[int(round(startend[x][1])), int(round(startend[x][0])), 1]
                                      for x in range(len(startend))])

                    # Projection of the sampled field onto the limb direction.
                    score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                    # The second term is non-positive and penalises limbs
                    # longer than half the image height.
                    score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                        0.5 * oriImg.shape[0] / norm - 1, 0)
                    criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                    criterion2 = score_with_dist_prior > 0
                    if criterion1 and criterion2:
                        connection_candidate.append(
                            [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

            # Greedy assignment: best-scoring candidates first, each peak used once.
            connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
            connection = np.zeros((0, 5))
            for i, j, s in (entry[0:3] for entry in connection_candidate):
                if i not in connection[:, 3] and j not in connection[:, 4]:
                    connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                    if len(connection) >= min(nA, nB):
                        break

            connection_all.append(connection)
        else:
            special_k.append(k)
            connection_all.append([])

    # last number in each row is the total parts number of that person
    # the second last number in each row is the score of the overall configuration
    subset = -1 * np.ones((0, 20))
    candidate = np.array([item for sublist in all_peaks for item in sublist])

    for k in range(len(mapIdx)):
        if k not in special_k:
            partAs = connection_all[k][:, 0]
            partBs = connection_all[k][:, 1]
            indexA, indexB = np.array(limbSeq[k]) - 1

            for i in range(len(connection_all[k])):
                found = 0
                subset_idx = [-1, -1]
                for j in range(len(subset)):
                    if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                        subset_idx[found] = j
                        found += 1

                if found == 1:
                    j = subset_idx[0]
                    if subset[j][indexB] != partBs[i]:
                        subset[j][indexB] = partBs[i]
                        subset[j][-1] += 1
                        subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                elif found == 2:  # if found 2 and disjoint, merge them
                    j1, j2 = subset_idx
                    membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                    if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                        subset[j1][:-2] += (subset[j2][:-2] + 1)
                        subset[j1][-2:] += subset[j2][-2:]
                        subset[j1][-2] += connection_all[k][i][2]
                        subset = np.delete(subset, j2, 0)
                    else:  # as like found == 1
                        subset[j1][indexB] = partBs[i]
                        subset[j1][-1] += 1
                        subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                # if find no partA in the subset, create a new subset
                elif not found and k < 17:
                    row = -1 * np.ones(20)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    row[-1] = 2
                    row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
                    subset = np.vstack([subset, row])

    # delete some rows of subset which has few parts occur
    deleteIdx = [
        i for i in range(len(subset))
        if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4
    ]
    subset = np.delete(subset, deleteIdx, axis=0)

    # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
    # candidate: x, y, score, id
    return candidate, subset
class FaceNet(Module):
    """Six-stage convolutional network producing one 71-channel heatmap per stage.

    A shared convolutional feature extractor feeds stage 1; every later stage
    consumes the previous stage's heatmaps concatenated with the shared
    feature map. Layer attribute names are kept stable because they are the
    keys of the state dict.
    """

    def __init__(self):
        super().__init__()
        self.relu = ReLU()
        self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2)

        # cnn to make feature map: (attribute name, in_channels, out_channels),
        # all 3x3 convolutions with stride 1 and padding 1.
        feature_layers = (
            ("conv1_1", 3, 64),
            ("conv1_2", 64, 64),
            ("conv2_1", 64, 128),
            ("conv2_2", 128, 128),
            ("conv3_1", 128, 256),
            ("conv3_2", 256, 256),
            ("conv3_3", 256, 256),
            ("conv3_4", 256, 256),
            ("conv4_1", 256, 512),
            ("conv4_2", 512, 512),
            ("conv4_3", 512, 512),
            ("conv4_4", 512, 512),
            ("conv5_1", 512, 512),
            ("conv5_2", 512, 512),
            ("conv5_3_CPM", 512, 128),
        )
        for attr, channels_in, channels_out in feature_layers:
            setattr(self, attr, Conv2d(
                in_channels=channels_in, out_channels=channels_out, kernel_size=3, stride=1, padding=1))

        # stage1
        self.conv6_1_CPM = Conv2d(in_channels=128, out_channels=512, kernel_size=1, stride=1, padding=0)
        self.conv6_2_CPM = Conv2d(in_channels=512, out_channels=71, kernel_size=1, stride=1, padding=0)

        # stage2 .. stage6 share one layout: five 7x7 convolutions followed by
        # two 1x1 convolutions. The first takes 199 channels (71 heatmap
        # channels + 128 feature channels).
        for stage in range(2, 7):
            for index in range(1, 6):
                setattr(self, f"Mconv{index}_stage{stage}", Conv2d(
                    in_channels=199 if index == 1 else 128, out_channels=128, kernel_size=7, stride=1, padding=3))
            setattr(self, f"Mconv6_stage{stage}", Conv2d(
                in_channels=128, out_channels=128, kernel_size=1, stride=1, padding=0))
            setattr(self, f"Mconv7_stage{stage}", Conv2d(
                in_channels=128, out_channels=71, kernel_size=1, stride=1, padding=0))

        # Every convolution starts with a zero bias.
        for module in self.modules():
            if isinstance(module, Conv2d):
                init.constant_(module.bias, 0)

    def forward(self, x):
        """Return a list of heatmaps."""
        activate = self.relu

        # Feature extractor: groups of conv+ReLU; the first three groups are
        # each followed by a 2x2 max-pool.
        conv_groups = (
            (("conv1_1", "conv1_2"), True),
            (("conv2_1", "conv2_2"), True),
            (("conv3_1", "conv3_2", "conv3_3", "conv3_4"), True),
            (("conv4_1", "conv4_2", "conv4_3", "conv4_4", "conv5_1", "conv5_2", "conv5_3_CPM"), False),
        )
        hidden = x
        for layer_names, pool_after in conv_groups:
            for layer_name in layer_names:
                hidden = activate(getattr(self, layer_name)(hidden))
            if pool_after:
                hidden = self.max_pooling_2d(hidden)
        feature_map = hidden

        # stage1: the last convolution of every stage has no activation
        hidden = self.conv6_2_CPM(activate(self.conv6_1_CPM(feature_map)))
        heatmaps = [hidden]

        # stage2 .. stage6
        for stage in range(2, 7):
            hidden = torch.cat([hidden, feature_map], dim=1)  # channel concat
            for index in range(1, 7):
                hidden = activate(getattr(self, f"Mconv{index}_stage{stage}")(hidden))
            hidden = getattr(self, f"Mconv7_stage{stage}")(hidden)
            heatmaps.append(hidden)

        return heatmaps
class Face(object):
    """
    The OpenPose face landmark detector model.

    Args:
        face_model_path: path of the weights file loaded into ``FaceNet``
        inference_size: stored as ``self.inference_size``, default 736
        gaussian_sigma: stored as ``self.sigma``, default 2.5
        heatmap_peak_thresh: stored as ``self.threshold``, default 0.1

    Note:
        The three tuning values are only stored on the instance. Inference in
        this class resizes the crop to a fixed 384x384 input, and peak
        extraction uses a fixed 0.05 threshold.
    """

    def __init__(self, face_model_path, inference_size=None, gaussian_sigma=None, heatmap_peak_thresh=None):
        # ``or`` means any falsy value (None, 0) falls back to the default.
        self.inference_size = inference_size or params["inference_img_size"]
        self.sigma = gaussian_sigma or params['gaussian_sigma']
        self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"]
        self.model = FaceNet()
        self.model.load_state_dict(torch.load(face_model_path))
        self.model.eval()

    def to(self, device):
        """Move the network to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(self, face_img):
        """
        Compute face-landmark heatmaps for a face crop.

        Args:
            face_img: Image array indexed as ``[row, column, channel]``.

        Returns:
            numpy.ndarray: The last stage's heatmaps, resized to the crop's
            height and width, with the channel axis first.
        """
        device = next(iter(self.model.parameters())).device
        H, W, _ = face_img.shape
        w_size = 384
        x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5
        x_data = x_data.to(device)

        # Inference only. Without this, ``.numpy()`` raises on an output that
        # tracks gradients when the caller has not disabled autograd.
        with torch.no_grad():
            hs = self.model(x_data[None, ...])
            heatmaps = F.interpolate(
                hs[-1],
                (H, W),
                mode='bilinear', align_corners=True).cpu().numpy()[0]
        return heatmaps

    def compute_peaks_from_heatmaps(self, heatmaps):
        """
        Pick the strongest location of every heatmap channel.

        Channels with no value above 0.05 contribute no entry, so the result
        can have fewer rows than ``heatmaps`` has channels.

        Args:
            heatmaps: Array indexed as ``[channel, row, column]``.

        Returns:
            numpy.ndarray: One ``[x, y]`` row per channel that has a peak; an
            empty 1-D array when no channel has one.
        """
        all_peaks = []
        for part in range(heatmaps.shape[0]):
            map_ori = heatmaps[part]
            above_threshold = map_ori > 0.05
            if not above_threshold.any():
                continue
            rows, cols = np.nonzero(above_threshold)
            # First maximum in row-major order among the above-threshold pixels.
            strongest = np.argmax(map_ori[rows, cols])
            all_peaks.append([cols[strongest], rows[strongest]])

        return np.array(all_peaks)
def make_layers(block, no_relu_layers):
    """
    Build an ``nn.Sequential`` from an ordered layer specification.

    Args:
        block: Ordered mapping of layer name to spec. A name containing
            ``'pool'`` maps to ``[kernel_size, stride, padding]`` for a
            max-pool; any other name maps to
            ``[in_channels, out_channels, kernel_size, stride, padding]``
            for a convolution.
        no_relu_layers: Names of convolutions that must NOT be followed by a ReLU.

    Returns:
        nn.Sequential: Layers registered under their spec names; each ReLU is
        registered as ``'relu_' + <conv name>``.
    """
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(layers))


class bodypose_model(nn.Module):
    """
    Two-branch, six-stage body pose network.

    ``model0`` is the shared feature extractor (128 output channels). Each
    stage has an ``_1`` branch with 38 output channels and an ``_2`` branch
    with 19 output channels; stages 2-6 take the previous stage's two outputs
    concatenated with the shared features (38 + 19 + 128 = 185 channels).
    """

    def __init__(self):
        super(bodypose_model, self).__init__()

        # These layers have no relu layer: the last convolution of every stage
        # in both branches is a raw output. The list used to repeat
        # 'Mconv7_stage6_L1' and omit 'Mconv7_stage6_L2', which appended a ReLU
        # to the final ``_2`` output only.
        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
        blocks = {}
        block0 = OrderedDict([
            ('conv1_1', [3, 64, 3, 1, 1]),
            ('conv1_2', [64, 64, 3, 1, 1]),
            ('pool1_stage1', [2, 2, 0]),
            ('conv2_1', [64, 128, 3, 1, 1]),
            ('conv2_2', [128, 128, 3, 1, 1]),
            ('pool2_stage1', [2, 2, 0]),
            ('conv3_1', [128, 256, 3, 1, 1]),
            ('conv3_2', [256, 256, 3, 1, 1]),
            ('conv3_3', [256, 256, 3, 1, 1]),
            ('conv3_4', [256, 256, 3, 1, 1]),
            ('pool3_stage1', [2, 2, 0]),
            ('conv4_1', [256, 512, 3, 1, 1]),
            ('conv4_2', [512, 512, 3, 1, 1]),
            ('conv4_3_CPM', [512, 256, 3, 1, 1]),
            ('conv4_4_CPM', [256, 128, 3, 1, 1])
        ])

        # Stage 1
        block1_1 = OrderedDict([
            ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
        ])

        block1_2 = OrderedDict([
            ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
            ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
            ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
        ])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        # Replace every layer spec with the built module, in insertion order.
        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        # Attribute names are state-dict key prefixes, so they are assigned
        # explicitly and in a fixed order.
        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):
        """
        Run all six stages.

        Returns:
            tuple: ``(out6_1, out6_2)`` - the final 38-channel and 19-channel
            outputs of the two branches.
        """
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2
def smart_resize(x, s):
    """Resize image ``x`` to the target size ``s = (height, width)``.

    Arrays with 1 or 3 channels (or 2-D arrays) go straight to ``cv2.resize``;
    any other channel count is resized one channel at a time and re-stacked on
    axis 2. ``INTER_AREA`` is used when the size shrinks overall (judged by the
    sum of height and width), ``INTER_LANCZOS4`` otherwise.
    """
    target_h, target_w = s
    if x.ndim == 2:
        src_h, src_w = x.shape
        channels = 1
    else:
        src_h, src_w, channels = x.shape

    if channels not in (1, 3):
        # Resize every channel separately as a 2-D plane and re-stack.
        planes = [smart_resize(x[:, :, idx], s) for idx in range(channels)]
        return np.stack(planes, axis=2)

    scale = float(target_h + target_w) / float(src_h + src_w)
    method = cv2.INTER_AREA if scale < 1 else cv2.INTER_LANCZOS4
    return cv2.resize(x, (int(target_w), int(target_h)), interpolation=method)
def padRightDownCorner(img, stride, padValue):
    """Pad an ``H x W x C`` image at the bottom and right to multiples of ``stride``.

    Args:
        img: 3-D array indexed as ``[row, column, channel]``.
        stride: Positive integer the padded height and width must be divisible by.
        padValue: Constant written into the padded area.

    Returns:
        ``(img_padded, pad)`` where ``pad`` is ``[up, left, down, right]``.
        ``up`` and ``left`` are always 0.
    """
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    # `slice * 0 + padValue` builds a constant row/column with the image's
    # width/height, channel count and promoted dtype; np.tile repeats it.
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    # Use the LAST row/column as template. The previous `[-2:-1]` slice is
    # empty for 1-pixel-high (or wide) images, which silently produced no
    # padding at all while `pad` still reported a non-zero amount.
    pad_down = np.tile(img_padded[-1:, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -1:, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray:
    """
    Draw keypoints and connections representing hand pose on a given canvas.

    Args:
        canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose.
        keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn
                                          or None if no keypoints are present.

    Returns:
        np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose.

    Note:
        The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1.
        Points whose pixel coordinate is not greater than ``eps`` on either axis are skipped.
    """
    if not keypoints:
        return canvas

    # Imported lazily (and only when there is something to draw) so that
    # matplotlib is not loaded just by importing this module.
    import matplotlib as mpl

    H, W, _C = canvas.shape

    # Index pairs into `keypoints`; each pair is drawn as one line segment.
    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10],
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    for ie, (e1, e2) in enumerate(edges):
        k1 = keypoints[e1]
        k2 = keypoints[e2]
        if k1 is None or k2 is None:
            continue

        x1 = int(k1.x * W)
        y1 = int(k1.y * H)
        x2 = int(k2.x * W)
        y2 = int(k2.y * H)
        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
            # Hue advances with the edge index so every edge gets its own colour.
            cv2.line(canvas, (x1, y1), (x2, y2), mpl.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)

    for keypoint in keypoints:
        # Same tolerance for missing points as the edge loop above.
        if keypoint is None:
            continue
        x, y = keypoint.x, keypoint.y
        x = int(x * W)
        y = int(y * H)
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
    return canvas
Args: body (BodyResult): A BodyResult object containing the detected body pose keypoints. oriImg (numpy.ndarray): A 3D numpy array representing the original input image. Returns: List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left corner of the bounding box, the width (height) of the bounding box, and a boolean flag indicating whether the hand is a left hand (True) or a right hand (False). Notes: - The width and height of the bounding boxes are equal since the network requires squared input. - The minimum bounding box size is 20 pixels. """ ratioWristElbow = 0.33 detect_result = [] image_height, image_width = oriImg.shape[0:2] keypoints = body.keypoints # right hand: wrist 4, elbow 3, shoulder 2 # left hand: wrist 7, elbow 6, shoulder 5 left_shoulder = keypoints[5] left_elbow = keypoints[6] left_wrist = keypoints[7] right_shoulder = keypoints[2] right_elbow = keypoints[3] right_wrist = keypoints[4] # if any of three not detected has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) if not (has_left or has_right): return [] hands = [] #left hand if has_left: hands.append([ left_shoulder.x, left_shoulder.y, left_elbow.x, left_elbow.y, left_wrist.x, left_wrist.y, True ]) # right hand if has_right: hands.append([ right_shoulder.x, right_shoulder.y, right_elbow.x, right_elbow.y, right_wrist.x, right_wrist.y, False ]) for x1, y1, x2, y2, x3, y3, is_left in hands: # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); # const auto distanceElbowShoulder = 
def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]:
    """
    Detect the face in the input body pose keypoints and calculate the bounding box for the face.

    Args:
        body (BodyResult): A BodyResult object containing the detected body pose keypoints.
        oriImg (numpy.ndarray): A 3D numpy array representing the original input image.

    Returns:
        Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the
                                   bounding box and the width (height) of the bounding box, or None if the
                                   face is not detected or the bounding box width is less than 20 pixels.

    Notes:
        - The width and height of the bounding box are equal.
        - The minimum bounding box size is 20 pixels.
        - The box is clipped so it never extends past the right/bottom image border.
    """
    # left right eye ear 14 15 16 17
    image_height, image_width = oriImg.shape[0:2]

    keypoints = body.keypoints
    head = keypoints[0]
    left_eye = keypoints[14]
    right_eye = keypoints[15]
    left_ear = keypoints[16]
    right_ear = keypoints[17]

    if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)):
        return None

    # Half-size of the box: the largest head-to-landmark distance (Chebyshev),
    # scaled by 3.0 for eyes and 1.5 for ears.
    width = 0.0
    x0, y0 = head.x, head.y

    for keypoint, scale in ((left_eye, 3.0), (right_eye, 3.0), (left_ear, 1.5), (right_ear, 1.5)):
        if keypoint is None:
            continue
        d = max(abs(x0 - keypoint.x), abs(y0 - keypoint.y))
        width = max(width, d * scale)

    # Move from the head point (box centre) to the top-left corner.
    x, y = x0, y0

    x -= width
    y -= width

    if x < 0:
        x = 0

    if y < 0:
        y = 0

    # The full side is twice the half-size, so the overflow test must use the
    # full side too. It previously compared `x + width`, which let boxes run
    # past the image border whenever only their far half overflowed.
    side = width * 2
    width1 = side
    width2 = side

    if x + side > image_width:
        width1 = image_width - x

    if y + side > image_height:
        width2 = image_height - y

    width = min(width1, width2)

    if width >= 20:
        return int(x), int(y), int(width)
    else:
        return None
"table5_pidinet.pth" if os.path.isdir(pretrained_model_or_path): model_path = os.path.join(pretrained_model_or_path, filename) else: model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) model = pidinet() model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) model.eval() return cls(model) def to(self, device): self.model.to(device) return self def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, **kwargs): self.model.to(devices.device) device = next(iter(self.model.parameters())).device if not isinstance(input_image, np.ndarray): input_image = np.array(input_image, dtype=np.uint8) input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) assert input_image.ndim == 3 input_image = input_image[:, :, ::-1].copy() image_pidi = torch.from_numpy(input_image).float().to(device) image_pidi = image_pidi / 255.0 image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') edge = self.model(image_pidi)[-1] edge = edge.cpu().numpy() if apply_filter: edge = edge > 0.5 if safe: edge = safe_step(edge) edge = (edge * 255.0).clip(0, 255).astype(np.uint8) detected_map = edge[0, 0] detected_map = HWC3(detected_map) img = resize_image(input_image, image_resolution) H, W, _C = img.shape detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) if scribble: detected_map = nms(detected_map, 127, 3.0) detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) detected_map[detected_map > 4] = 255 detected_map[detected_map < 255] = 0 if opts.control_move_processor: self.model.to('cpu') if output_type == "pil": detected_map = Image.fromarray(detected_map) return detected_map ================================================ FILE: modules/control/proc/pidi_model.py ================================================ """ Author: Zhuo Su, 
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert one ``H x W x C`` numpy image, or a list of them, to ``C x H x W`` tensors.

    Args:
        imgs (list[ndarray] | ndarray): Input image(s).
        bgr2rgb (bool): Swap BGR to RGB (only applied to 3-channel images).
        float32 (bool): Cast the resulting tensor to float32.

    Returns:
        list[tensor] | tensor: A list when a list was passed in, otherwise a
        single tensor.
    """

    def _convert(array):
        # The channel count is read first, so a 2-D array fails here.
        if array.shape[2] == 3 and bgr2rgb:
            if array.dtype == 'float64':
                array = array.astype('float32')
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(array.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if not isinstance(imgs, list):
        return _convert(imgs)
    return [_convert(item) for item in imgs]
'layer0': 'cd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv', 'layer4': 'cd', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv', 'layer8': 'cd', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv', 'layer12': 'cd', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv', }, 'avvv4': { 'layer0': 'ad', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv', 'layer4': 'ad', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv', 'layer8': 'ad', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv', 'layer12': 'ad', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv', }, 'rvvv4': { 'layer0': 'rd', 'layer1': 'cv', 'layer2': 'cv', 'layer3': 'cv', 'layer4': 'rd', 'layer5': 'cv', 'layer6': 'cv', 'layer7': 'cv', 'layer8': 'rd', 'layer9': 'cv', 'layer10': 'cv', 'layer11': 'cv', 'layer12': 'rd', 'layer13': 'cv', 'layer14': 'cv', 'layer15': 'cv', }, 'cccv4': { 'layer0': 'cd', 'layer1': 'cd', 'layer2': 'cd', 'layer3': 'cv', 'layer4': 'cd', 'layer5': 'cd', 'layer6': 'cd', 'layer7': 'cv', 'layer8': 'cd', 'layer9': 'cd', 'layer10': 'cd', 'layer11': 'cv', 'layer12': 'cd', 'layer13': 'cd', 'layer14': 'cd', 'layer15': 'cv', }, 'aaav4': { 'layer0': 'ad', 'layer1': 'ad', 'layer2': 'ad', 'layer3': 'cv', 'layer4': 'ad', 'layer5': 'ad', 'layer6': 'ad', 'layer7': 'cv', 'layer8': 'ad', 'layer9': 'ad', 'layer10': 'ad', 'layer11': 'cv', 'layer12': 'ad', 'layer13': 'ad', 'layer14': 'ad', 'layer15': 'cv', }, 'rrrv4': { 'layer0': 'rd', 'layer1': 'rd', 'layer2': 'rd', 'layer3': 'cv', 'layer4': 'rd', 'layer5': 'rd', 'layer6': 'rd', 'layer7': 'cv', 'layer8': 'rd', 'layer9': 'rd', 'layer10': 'rd', 'layer11': 'cv', 'layer12': 'rd', 'layer13': 'rd', 'layer14': 'rd', 'layer15': 'cv', }, 'c16': { 'layer0': 'cd', 'layer1': 'cd', 'layer2': 'cd', 'layer3': 'cd', 'layer4': 'cd', 'layer5': 'cd', 'layer6': 'cd', 'layer7': 'cd', 'layer8': 'cd', 'layer9': 'cd', 'layer10': 'cd', 'layer11': 'cd', 'layer12': 'cd', 'layer13': 'cd', 'layer14': 'cd', 'layer15': 'cd', }, 'a16': { 'layer0': 'ad', 'layer1': 'ad', 'layer2': 'ad', 'layer3': 'ad', 
def createConvFunc(op_type):
    """Return the convolution callable for a pixel-difference op type.

    Args:
        op_type: One of ``'cv'``, ``'cd'``, ``'ad'``, ``'rd'``.

    Returns:
        A callable with the signature of ``F.conv2d``
        ``(x, weights, bias, stride, padding, dilation, groups)``:

        - ``'cv'``: ``F.conv2d`` itself.
        - ``'cd'``: 3x3 convolution minus a 1x1 convolution whose kernel is the
          spatial sum of the 3x3 weights.
        - ``'ad'``: 3x3 convolution with ``weights - permuted(weights)``.
        - ``'rd'``: the 3x3 weights are scattered (with +/- signs) into a 5x5
          kernel; ``padding`` is overridden to ``2 * dilation``.

    Raises:
        AssertionError: If ``op_type`` is unknown, or (inside the returned
            callables) if kernel size, dilation or padding are unsupported.
    """
    assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
    if op_type == 'cv':
        return F.conv2d

    if op_type == 'cd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
            assert padding == dilation, 'padding for cd_conv set wrong'

            # 1x1 kernel holding the sum of every 3x3 kernel.
            weights_c = weights.sum(dim=[2, 3], keepdim=True)
            yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
            y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y - yc
        return func
    elif op_type == 'ad':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
            assert padding == dilation, 'padding for ad_conv set wrong'

            shape = weights.shape
            weights = weights.view(shape[0], shape[1], -1)
            weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)  # clock-wise
            y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func
    elif op_type == 'rd':
        def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
            assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
            assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
            padding = 2 * dilation

            shape = weights.shape
            # Allocate the 5x5 kernel with the SAME dtype and device as the
            # weights. The old code always produced float32 (and, on CUDA, used
            # the legacy torch.cuda.FloatTensor on the *current* device), which
            # fails for half/double weights or weights on a non-default GPU.
            buffer = weights.new_zeros((shape[0], shape[1], 5 * 5))
            weights = weights.view(shape[0], shape[1], -1)
            # Flattened 3x3 entries 1..8 go to the outer ring (+) and to the
            # inner ring (-) of the 5x5 kernel; entry 0 is not used.
            buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
            buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
            buffer[:, :, 12] = 0
            buffer = buffer.view(shape[0], shape[1], 5, 5)
            y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
            return y
        return func
    else:
        # Only reachable when assertions are disabled (python -O).
        print('impossible to be here unless you force that')
        return None
class CDCM(nn.Module):
    """
    Compact Dilation Convolution based Module

    A ReLU and a 1x1 convolution (bias initialised to 0) feed four parallel
    bias-free 3x3 convolutions with dilation 5, 7, 9 and 11; the four results
    are summed. Padding equals dilation, so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(CDCM, self).__init__()

        def dilated_branch(rate):
            # 3x3 conv whose padding matches its dilation.
            return nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=rate, padding=rate, bias=False)

        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.conv2_1 = dilated_branch(5)
        self.conv2_2 = dilated_branch(7)
        self.conv2_3 = dilated_branch(9)
        self.conv2_4 = dilated_branch(11)
        nn.init.constant_(self.conv1.bias, 0)

    def forward(self, x):
        reduced = self.conv1(self.relu1(x))
        total = self.conv2_1(reduced)
        for branch in (self.conv2_2, self.conv2_3, self.conv2_4):
            total = total + branch(reduced)
        return total
class PiDiNet(nn.Module):
    """Four-stage edge-detection network built from pixel-difference blocks.

    Each stage yields a feature map that is optionally passed through a CDCM
    (when ``dil`` is set) and/or a CSAM (when ``sa`` is true), reduced to one
    channel by a MapReduce, and upsampled to the input size. A 1x1 classifier
    fuses the four maps.

    Args:
        inplane: Channel count ``C`` of the first stage; later stages use
            ``2C``, ``4C`` and ``4C``.
        pdcs: Indexable with entries 0..15, one per layer, passed to the
            initial convolution and to the 15 blocks.
        dil: ``None`` or an int; when set, it is the output channel count of
            the CDCM modules.
        sa: Whether to add CSAM attention modules.
        convert: When true, vanilla ``nn.Conv2d`` / ``PDCBlock_converted`` are
            used instead of ``Conv2d`` / ``PDCBlock``; ``pdcs[0]`` is then
            compared against the string ``'rd'`` to pick the initial kernel size.
    """

    def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
        super(PiDiNet, self).__init__()
        self.sa = sa
        if dil is not None:
            assert isinstance(dil, int), 'dil should be an int'
        self.dil = dil

        # Channel count of each stage output, filled in as stages are built.
        self.fuseplanes = []

        self.inplane = inplane
        if convert:
            if pdcs[0] == 'rd':
                init_kernel_size = 5
                init_padding = 2
            else:
                init_kernel_size = 3
                init_padding = 1
            self.init_block = nn.Conv2d(3, self.inplane,
                    kernel_size=init_kernel_size, padding=init_padding, bias=False)
            block_class = PDCBlock_converted
        else:
            self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
            block_class = PDCBlock

        # Stage 1: three blocks at the input resolution.
        self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
        self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
        self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane) # C

        # Stage 2: first block has stride 2 and doubles the channels.
        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
        self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
        self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
        self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane) # 2C

        # Stage 3: stride 2 again, channels doubled again.
        inplane = self.inplane
        self.inplane = self.inplane * 2
        self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
        self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
        self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
        self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane) # 4C

        # Stage 4: stride 2, channel count unchanged.
        self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
        self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
        self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
        self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
        self.fuseplanes.append(self.inplane) # 4C

        # Per-stage heads. Which module lists exist depends on sa / dil; the
        # MapReduce input width is `dil` whenever a CDCM precedes it.
        self.conv_reduces = nn.ModuleList()
        if self.sa and self.dil is not None:
            self.attentions = nn.ModuleList()
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.attentions.append(CSAM(self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        elif self.sa:
            self.attentions = nn.ModuleList()
            for i in range(4):
                self.attentions.append(CSAM(self.fuseplanes[i]))
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
        elif self.dil is not None:
            self.dilations = nn.ModuleList()
            for i in range(4):
                self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
                self.conv_reduces.append(MapReduce(self.dil))
        else:
            for i in range(4):
                self.conv_reduces.append(MapReduce(self.fuseplanes[i]))

        # Fuses the four single-channel maps; starts out as a plain average.
        self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
        nn.init.constant_(self.classifier.weight, 0.25)
        nn.init.constant_(self.classifier.bias, 0)

        # print('initialization done')

    def get_weights(self):
        """Split parameters into three lists by substring of the parameter name.

        Returns:
            ``(conv_weights, bn_weights, relu_weights)``: names containing
            ``'bn'`` go to the second list, names containing ``'relu'`` to the
            third, everything else to the first.
        """
        conv_weights = []
        bn_weights = []
        relu_weights = []
        for pname, p in self.named_parameters():
            if 'bn' in pname:
                bn_weights.append(p)
            elif 'relu' in pname:
                relu_weights.append(p)
            else:
                conv_weights.append(p)

        return conv_weights, bn_weights, relu_weights

    def forward(self, x):
        """Return five sigmoid-activated maps at the input's spatial size.

        The list holds the four per-stage maps followed by the fused map
        produced by ``classifier``.
        """
        # Remember the input size; every side output is resized back to it.
        H, W = x.size()[2:]

        x = self.init_block(x)

        x1 = self.block1_1(x)
        x1 = self.block1_2(x1)
        x1 = self.block1_3(x1)

        x2 = self.block2_1(x1)
        x2 = self.block2_2(x2)
        x2 = self.block2_3(x2)
        x2 = self.block2_4(x2)

        x3 = self.block3_1(x2)
        x3 = self.block3_2(x3)
        x3 = self.block3_3(x3)
        x3 = self.block3_4(x3)

        x4 = self.block4_1(x3)
        x4 = self.block4_2(x4)
        x4 = self.block4_3(x4)
        x4 = self.block4_4(x4)

        # Mirror the head configuration chosen in __init__.
        x_fuses = []
        if self.sa and self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](self.dilations[i](xi)))
        elif self.sa:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.attentions[i](xi))
        elif self.dil is not None:
            for i, xi in enumerate([x1, x2, x3, x4]):
                x_fuses.append(self.dilations[i](xi))
        else:
            x_fuses = [x1, x2, x3, x4]

        e1 = self.conv_reduces[0](x_fuses[0])
        e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)

        e2 = self.conv_reduces[1](x_fuses[1])
        e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)

        e3 = self.conv_reduces[2](x_fuses[2])
        e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)

        e4 = self.conv_reduces[3](x_fuses[3])
        e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)

        outputs = [e1, e2, e3, e4]

        output = self.classifier(torch.cat(outputs, dim=1))
        #if not self.training:
        #    return torch.sigmoid(output)

        outputs.append(output)
        outputs = [torch.sigmoid(r) for r in outputs]
        return outputs
def pidinet():
    """Build the PiDiNet used by this module.

    Uses the ``'carv4'`` layer configuration, 60 base channels, CDCM modules
    with 24 channels and attention enabled.
    """
    layer_ops = config_model('carv4')
    return PiDiNet(60, layer_ops, dil=24, sa=True)
import warnings from typing import Union import cv2 import numpy as np import torch from huggingface_hub import hf_hub_download from PIL import Image from modules import devices from modules.shared import opts from modules.control.util import HWC3, resize_image from .automatic_mask_generator import SamAutomaticMaskGenerator from .build_sam import sam_model_registry class SamDetector: def __init__(self, mask_generator: SamAutomaticMaskGenerator = None): self.model = mask_generator @classmethod def from_pretrained(cls, model_path, filename, model_type, cache_dir=None, local_files_only=False): """ Possible model_type : vit_h, vit_l, vit_b, vit_t download weights from https://github.com/facebookresearch/segment-anything """ model_path = hf_hub_download(model_path, filename, cache_dir=cache_dir, local_files_only=local_files_only) sam = sam_model_registry[model_type](checkpoint=model_path) sam.to(devices.device) mask_generator = SamAutomaticMaskGenerator(sam) return cls(mask_generator) def show_anns(self, anns): from numpy.random import default_rng gen = default_rng() if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) h, w = anns[0]['segmentation'].shape final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") for ann in sorted_anns: m = ann['segmentation'] img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) for i in range(3): img[:,:,i] = gen.integers(255, dtype=np.uint8) final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255))) return np.array(final_img, dtype=np.uint8) def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> Image.Image: if "image" in kwargs: warnings.warn("image is deprecated, please use `input_image=...` instead.", DeprecationWarning) input_image = kwargs.pop("image") if input_image is None: raise ValueError("input_image must be defined.") if not 
isinstance(input_image, np.ndarray): input_image = np.array(input_image, dtype=np.uint8) input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) # Generate Masks self.model.predictor.model.to(devices.device) masks = self.model.generate(input_image) if opts.control_move_processor: self.model.predictor.model.to('cpu') # Create map image_map = self.show_anns(masks) detected_map = image_map detected_map = HWC3(detected_map) img = resize_image(input_image, image_resolution) H, W, _C = img.shape detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) if output_type == "pil": detected_map = Image.fromarray(detected_map) return detected_map ================================================ FILE: modules/control/proc/segment_anything/automatic_mask_generator.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
from .modeling import Sam
from .predictor import SamPredictor
from .utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


class SamAutomaticMaskGenerator:
    """
    Prompts a SAM model with a grid of points and returns the filtered,
    de-duplicated masks as a list of annotation records (see `generate`).
    """

    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculated the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """

        # Exactly one of the two point sources may be given: `!=` on the two
        # booleans acts as an exclusive-or.
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode

    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                # the looser (larger) of the two IoU cutoffs is used for the re-run of NMS
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            # 'uncompressed_rle': the stored RLEs are returned unchanged
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """
        Runs mask prediction on every crop of `image` and merges the results.

        Crop boxes come from `generate_crop_boxes`; each crop is handled by
        `_process_crop` and concatenated into one MaskData. When more than one
        crop was processed, NMS is run across crops. The result is converted
        with `to_numpy()` before being returned.
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Predicts masks for one crop of `image`.

        Arguments:
          image (np.ndarray): The full image; indexed as [y, x, channel].
          crop_box (list(int)): The crop as (x0, y0, x1, y1).
          crop_layer_idx (int): Index into self.point_grids for this crop.
          orig_size (tuple(int)): (height, width) of the full image.

        Returns:
          MaskData: masks kept after within-crop NMS, with 'boxes' and 'points'
            mapped back to the full-image frame and a 'crop_boxes' entry
            repeating `crop_box` once per kept mask.
        """
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        # [::-1] flips (H, W) to (W, H) so the grid is scaled in (x, y) order
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """
        Runs the predictor on one batch of point prompts and filters the output.

        Predictions are filtered by predicted IoU, then by stability score,
        then by proximity of their box to the crop edge. Surviving masks are
        uncropped to the original image size and stored as RLEs; the dense
        'masks' entry is deleted before returning.

        Arguments:
          points (np.ndarray): Point prompts for this batch, in crop coordinates.
          im_size (tuple(int)): Size of the cropped image.
          crop_box (list(int)): The crop as (x0, y0, x1, y1).
          orig_size (tuple(int)): (height, width) of the full image.

        Returns:
          MaskData: the filtered predictions for this batch.
        """
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        # every prompt point gets label 1
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            # each point is repeated once per mask predicted for it
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data


================================================
FILE: modules/control/proc/segment_anything/build_sam.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch from functools import partial from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT def build_sam_vit_h(checkpoint=None): return _build_sam( encoder_embed_dim=1280, encoder_depth=32, encoder_num_heads=16, encoder_global_attn_indexes=[7, 15, 23, 31], checkpoint=checkpoint, ) build_sam = build_sam_vit_h def build_sam_vit_l(checkpoint=None): return _build_sam( encoder_embed_dim=1024, encoder_depth=24, encoder_num_heads=16, encoder_global_attn_indexes=[5, 11, 17, 23], checkpoint=checkpoint, ) def build_sam_vit_b(checkpoint=None): return _build_sam( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=[2, 5, 8, 11], checkpoint=checkpoint, ) def build_sam_vit_t(checkpoint=None): prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size mobile_sam = Sam( image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, embed_dims=[64, 128, 160, 320], depths=[2, 2, 6, 2], num_heads=[2, 4, 5, 10], window_sizes=[7, 7, 14, 7], mlp_ratio=4., drop_rate=0., drop_path_rate=0.0, use_checkpoint=False, mbconv_expand_ratio=4.0, local_conv_size=3, layer_lr_decay=0.8 ), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer( depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, ), pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) mobile_sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f) mobile_sam.load_state_dict(state_dict) return mobile_sam sam_model_registry = { "default": build_sam_vit_h, "vit_h": build_sam_vit_h, "vit_l": build_sam_vit_l, "vit_b": 
build_sam_vit_b, "vit_t": build_sam_vit_t, } def _build_sam( encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, ): prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size sam = Sam( image_encoder=ImageEncoderViT( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), num_heads=encoder_num_heads, patch_size=vit_patch_size, qkv_bias=True, use_rel_pos=True, global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, ), prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, image_embedding_size=(image_embedding_size, image_embedding_size), input_image_size=(image_size, image_size), mask_in_chans=16, ), mask_decoder=MaskDecoder( num_multimask_outputs=3, transformer=TwoWayTransformer( depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, ), pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], ) sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: state_dict = torch.load(f) sam.load_state_dict(state_dict) return sam ================================================ FILE: modules/control/proc/segment_anything/modeling/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from .sam import Sam from .image_encoder import ImageEncoderViT from .mask_decoder import MaskDecoder from .prompt_encoder import PromptEncoder from .transformer import TwoWayTransformer from .tiny_vit_sam import TinyViT ================================================ FILE: modules/control/proc/segment_anything/modeling/common.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn from typing import Type class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x ================================================ FILE: modules/control/proc/segment_anything/modeling/image_encoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Type from .common import LayerNorm2d, MLPBlock # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. 
""" super().__init__() self.img_size = img_size self.patch_embed = PatchEmbed( kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. 
act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. 
""" super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
(Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. 
rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class PatchEmbed(nn.Module): """ Image to Patch Embedding. 
""" def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x ================================================ FILE: modules/control/proc/segment_anything/modeling/mask_decoder.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from torch import nn from torch.nn import functional as F from typing import List, Tuple, Type from .common import LayerNorm2d class MaskDecoder(nn.Module): def __init__( self, *, transformer_dim: int, transformer: nn.Module, num_multimask_outputs: int = 3, activation: Type[nn.Module] = nn.GELU, iou_head_depth: int = 3, iou_head_hidden_dim: int = 256, ) -> None: """ Predicts masks given an image and prompt embeddings, using a transformer architecture. 
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run mask prediction and keep either the single mask or the multimask set.

    Arguments:
      image_embeddings (torch.Tensor): embeddings produced by the image encoder
      image_pe (torch.Tensor): positional encoding shaped like image_embeddings
      sparse_prompt_embeddings (torch.Tensor): point and box embeddings
      dense_prompt_embeddings (torch.Tensor): mask-input embeddings
      multimask_output (bool): True keeps outputs 1.. (several masks),
        False keeps only output 0 (one mask).

    Returns:
      torch.Tensor: batched predicted masks
      torch.Tensor: batched predictions of mask quality
    """
    all_masks, all_iou = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # The same slice is applied to the mask channel axis and the IoU columns
    # so that each returned mask keeps its matching quality score.
    keep = slice(1, None) if multimask_output else slice(0, 1)
    return all_masks[:, keep, :, :], all_iou[:, keep]
class MLP(nn.Module):
    """
    Stack of linear layers with ReLU between them.

    The first layer maps ``input_dim`` to ``hidden_dim``, the last maps to
    ``output_dim``; no ReLU follows the last layer. When ``sigmoid_output``
    is set, a sigmoid is applied to the final result.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.sigmoid_output = sigmoid_output
        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        widths = [input_dim, *([hidden_dim] * (num_layers - 1)), output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        final_index = self.num_layers - 1
        for depth, linear in enumerate(self.layers):
            x = linear(x)
            if depth < final_index:
                x = F.relu(x)
        return F.sigmoid(x) if self.sigmoid_output else x
""" super().__init__() self.embed_dim = embed_dim self.input_image_size = input_image_size self.image_embedding_size = image_embedding_size self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) self.mask_downscaling = nn.Sequential( nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans // 4), activation(), nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans), activation(), nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), ) self.no_mask_embed = nn.Embedding(1, embed_dim) def get_dense_pe(self) -> torch.Tensor: """ Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the image encoding. 
def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    """
    Embed point prompts.

    Labels select which learned vector is added to the positional encoding:
    -1 replaces the encoding with the "not a point" embedding, 0 and 1 add
    point embeddings 0 and 1 respectively. With ``pad`` set, one extra
    zero point with label -1 is appended to every batch entry.
    """
    # Move coordinates to the pixel centre.
    points = points + 0.5
    if pad:
        batch = points.shape[0]
        extra_point = torch.zeros((batch, 1, 2), device=points.device)
        extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
        points = torch.cat([points, extra_point], dim=1)
        labels = torch.cat([labels, extra_label], dim=1)

    embedded = self.pe_layer.forward_with_coords(points, self.input_image_size)

    # Padding entries carry no position: wipe the encoding, then add the
    # dedicated embedding.
    is_padding = labels == -1
    embedded[is_padding] = 0.0
    embedded[is_padding] += self.not_a_point_embed.weight

    for label_value in (0, 1):
        embedded[labels == label_value] += self.point_embeddings[label_value].weight
    return embedded
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.

    A fixed random (2, num_pos_feats) matrix is stored as a buffer; 2-D
    coordinates are projected through it and expanded into sine and cosine
    features, giving ``2 * num_pos_feats`` output channels.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """
        Args:
            num_pos_feats (int): number of random frequencies per coordinate
                axis (half of the output channel count).
            scale (float or None): multiplier for the random matrix; ``None``
                or a non-positive value falls back to 1.0.
        """
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified (h, w) size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cumulative sums give 1..h / 1..w; subtracting 0.5 samples cell centres.
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Positionally encode points that are not normalized to [0,1].

        Args:
            coords_input (torch.Tensor): coordinates indexed as
                ``[:, :, 0]`` (divided by ``image_size[1]``) and
                ``[:, :, 1]`` (divided by ``image_size[0]``). The input is
                not modified.
            image_size (Tuple[int, int]): sizes used for the normalisation.

        Returns:
            torch.Tensor: encoding of the normalised coordinates (B x N x C).
        """
        coords = coords_input.clone()
        # Integer tensors cannot hold the normalised fractions: writing the
        # division result back in place would truncate it to 0. Promote
        # first; floating-point inputs keep their dtype for the division.
        if not coords.is_floating_point():
            coords = coords.to(torch.float)
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
def forward(
    self,
    batched_input: List[Dict[str, Any]],
    multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
    """
    Run the full pipeline (encode images, encode prompts, decode masks) for
    a list of images. When prompts are not known ahead of time,
    SamPredictor is the recommended entry point instead of this method.

    Arguments:
      batched_input (list(dict)): one dictionary per image. Prompt keys
        may be left out when that prompt is absent.
          'image': 3xHxW torch tensor already transformed for the model.
          'original_size': (tuple(int, int)) size before transformation,
            as (H, W).
          'point_coords': (torch.Tensor) BxNx2 point prompts, already in
            the model's input frame.
          'point_labels': (torch.Tensor) BxN labels for the point prompts.
          'boxes': (torch.Tensor) Bx4 box prompts, already in the model's
            input frame.
          'mask_inputs': (torch.Tensor) Bx1xHxW mask prompts.
      multimask_output (bool): forwarded to the mask decoder; selects
        several disambiguating masks versus a single mask.

    Returns:
      (list(dict)): one dictionary per input image with the keys
          'masks': (torch.Tensor) boolean masks, BxCxHxW at the original
            image size (logits compared against ``mask_threshold``).
          'iou_predictions': (torch.Tensor) BxC mask quality predictions.
          'low_res_logits': (torch.Tensor) the decoder's low resolution
            logits, BxCxHxW; usable as mask input for a later call.
    """
    preprocessed = [self.preprocess(record["image"]) for record in batched_input]
    embeddings = self.image_encoder(torch.stack(preprocessed, dim=0))

    results = []
    for record, embedding in zip(batched_input, embeddings):
        # Point prompts are optional; labels are only read when coords exist.
        points = (
            (record["point_coords"], record["point_labels"])
            if "point_coords" in record
            else None
        )
        sparse, dense = self.prompt_encoder(
            points=points,
            boxes=record.get("boxes", None),
            masks=record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=multimask_output,
        )
        upscaled = self.postprocess_masks(
            low_res_masks,
            input_size=record["image"].shape[-2:],
            original_size=record["original_size"],
        )
        results.append(
            {
                "masks": upscaled > self.mask_threshold,
                "iou_predictions": iou_predictions,
                "low_res_logits": low_res_masks,
            }
        )
    return results
class Conv2d_BN(torch.nn.Sequential):
    """Bias-free Conv2d (child ``c``) followed by BatchNorm2d (child ``bn``)."""

    def __init__(self, a, b, ks=1, stride=1, pad=0,
                 dilation=1, groups=1, bn_weight_init=1):
        """
        Args:
            a: input channels of the convolution.
            b: output channels of the convolution (and BatchNorm features).
            ks, stride, pad, dilation, groups: passed positionally to Conv2d.
            bn_weight_init: constant the BatchNorm weight is filled with;
                the BatchNorm bias is always filled with 0.
        """
        super().__init__()
        conv = torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
        self.add_module('c', conv)
        norm = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(norm.weight, bn_weight_init)
        torch.nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)

    def fuse(self):
        """Return a new Conv2d (with bias) that folds the BatchNorm running
        statistics and affine parameters into the convolution."""
        conv, norm = self._modules.values()
        std = (norm.running_var + norm.eps) ** 0.5
        per_channel_scale = norm.weight / std
        fused_weight = conv.weight * per_channel_scale[:, None, None, None]
        fused_bias = norm.bias - norm.running_mean * norm.weight / std
        # Weight dim 1 is in_channels / groups, so multiply back by groups.
        fused = torch.nn.Conv2d(
            fused_weight.size(1) * self.c.groups,
            fused_weight.size(0),
            fused_weight.shape[2:],
            stride=self.c.stride,
            padding=self.c.padding,
            dilation=self.c.dilation,
            groups=self.c.groups,
        )
        fused.weight.data.copy_(fused_weight)
        fused.bias.data.copy_(fused_bias)
        return fused
class PatchMerging(nn.Module):
    """
    Three Conv2d_BN stages (1x1, depthwise 3x3, 1x1) that change the channel
    count from ``dim`` to ``out_dim`` and return a flattened token sequence.
    """

    def __init__(self, input_resolution, dim, out_dim, activation):
        """
        Args:
            input_resolution: (H, W) used to unflatten 3-D inputs.
            dim: input channels.
            out_dim: output channels.
            activation: callable building the activation module.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        # The depthwise conv uses stride 1 for these output widths and
        # stride 2 for every other width.
        depthwise_stride = 1 if out_dim in (320, 448, 576) else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, depthwise_stride, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            # Token sequence -> (B, C, H, W) image layout.
            height, width = self.input_resolution
            x = x.view(len(x), height, width, -1).permute(0, 3, 1, 2)

        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        # Back to a token sequence with channels last.
        return self.conv3(x).flatten(2).transpose(1, 2)
class Attention(torch.nn.Module):
    """
    Multi-head self-attention over a fixed window with a learned bias.

    One bias value per head is learned for every distinct absolute
    (row, column) offset between two window positions;
    ``attention_bias_idxs`` maps each position pair to its offset index.
    Outside training mode the gathered bias table is cached in ``self.ab``.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=(14, 14),
                 ):
        """
        Args:
            dim: input and output feature size.
            key_dim: per-head size of queries and keys.
            num_heads: number of attention heads.
            attn_ratio: per-head value size is ``int(attn_ratio * key_dim)``.
            resolution: (h, w) tuple of the attention window.
        """
        super().__init__()
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        points = list(itertools.product(
            range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N),
                             persistent=False)

    def train(self, mode=True):
        """
        Switch mode and maintain the cached bias table.

        Entering training mode drops the cache (forward gathers from the
        parameter directly); leaving it rebuilds the cache. Returns ``self``
        like ``nn.Module.train`` so that ``model = model.eval()`` works.
        """
        super().train(mode)
        if mode:
            if hasattr(self, 'ab'):
                del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
        return self

    def _eval_bias(self):
        """Return the cached bias table, rebuilding it when it is missing or
        no longer matches the parameter's device or dtype (e.g. the module
        was moved after ``eval()``)."""
        cached = getattr(self, 'ab', None)
        source = self.attention_biases
        if cached is None or cached.device != source.device or cached.dtype != source.dtype:
            cached = source[:, self.attention_bias_idxs]
            self.ab = cached
        return cached

    def forward(self, x):  # x (B,N,C)
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d], dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        if self.training:
            bias = self.attention_biases[:, self.attention_bias_idxs]
        else:
            bias = self._eval_bias()
        attn = (q @ k.transpose(-2, -1)) * self.scale + bias
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
class BasicLayer(nn.Module):
    """ One TinyViT stage: a stack of TinyViTBlock modules, optionally
    followed by a downsample module.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a
            list is indexed per block, a scalar is shared. Default: 0.0
        downsample (callable | None, optional): Factory for the layer applied
            after the blocks. Default: None
        use_checkpoint (bool): Run each block through
            torch.utils.checkpoint. Default: False.
        local_conv_size: kernel size forwarded to every block. Default: 3
        activation: the activation function. Default: nn.GELU
        out_dim: output dimension handed to the downsample factory.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0.,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 local_conv_size=3,
                 activation=nn.GELU,
                 out_dim=None,
                 ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        def rate_for(block_index):
            # A list carries one rate per block; anything else is shared.
            if isinstance(drop_path, list):
                return drop_path[block_index]
            return drop_path

        stage_blocks = []
        for block_index in range(depth):
            stage_blocks.append(
                TinyViTBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    drop_path=rate_for(block_index),
                    local_conv_size=local_conv_size,
                    activation=activation,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        # Optional merging layer built from the supplied factory.
        if downsample is None:
            self.downsample = None
        else:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        for block in self.blocks:
            x = checkpoint.checkpoint(block, x) if self.use_checkpoint else block(x)
        return x if self.downsample is None else self.downsample(x)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class TinyViT(nn.Module):
    """
    TinyViT backbone used as an image encoder.

    The model is a patch embedding, one convolutional stage (ConvLayer) and
    further attention stages (BasicLayer), followed by a convolutional
    ``neck`` that projects the last stage to 256 channels. ``forward``
    returns the neck features; ``norm_head`` and ``head`` are constructed
    but not applied.
    """

    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
                 embed_dims=None, depths=None,
                 num_heads=None,
                 window_sizes=None,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 ):
        """
        Args:
            img_size: input resolution handed to the patch embedding.
            in_chans: number of input image channels.
            num_classes: size of the (unused in forward) linear head; a value
                <= 0 replaces the head with Identity.
            embed_dims: channel width per stage. Default [96, 192, 384, 768].
            depths: number of blocks per stage. Default [2, 2, 6, 2].
            num_heads: attention heads per stage. Default [3, 6, 12, 24].
            window_sizes: attention window per stage. Default [7, 7, 14, 7].
            mlp_ratio: MLP hidden/embedding ratio for attention stages.
            drop_rate: dropout rate for attention stages.
            drop_path_rate: final stochastic depth rate; per-block rates are
                spaced linearly from 0 to this value.
            use_checkpoint: forwarded to every stage.
            mbconv_expand_ratio: expansion ratio of the first (conv) stage.
            local_conv_size: local conv kernel size of attention stages.
            layer_lr_decay: base of the per-block ``lr_scale`` attributes.
        """
        if window_sizes is None:
            window_sizes = [7, 7, 14, 7]
        if num_heads is None:
            num_heads = [3, 6, 12, 24]
        if depths is None:
            depths = [2, 2, 6, 2]
        if embed_dims is None:
            embed_dims = [96, 192, 384, 768]
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
                                                sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # Stage index 3 reuses the reduction factor of stage 2 instead
            # of halving the resolution again.
            reduction = 2 ** (i_layer - 1 if i_layer == 3 else i_layer)
            kwargs = dict(dim=embed_dims[i_layer],
                          input_resolution=(patches_resolution[0] // reduction,
                                            patches_resolution[1] // reduction),
                          depth=depths[i_layer],
                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                          downsample=PatchMerging if (
                              i_layer < self.num_layers - 1) else None,
                          use_checkpoint=use_checkpoint,
                          out_dim=embed_dims[min(
                              i_layer + 1, len(embed_dims) - 1)],
                          activation=activation,
                          )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs,
                )
            else:
                layer = BasicLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    local_conv_size=local_conv_size,
                    **kwargs)
            self.layers.append(layer)

        # Classifier head
        self.norm_head = nn.LayerNorm(embed_dims[-1])
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

        # init weights
        self.apply(self._init_weights)
        self.set_layer_lr_decay(layer_lr_decay)
        # The neck is created after weight init and lr-scale assignment, so
        # neither is applied to it.
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dims[-1],
                256,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(256),
            nn.Conv2d(
                256,
                256,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(256),
        )

    def set_layer_lr_decay(self, layer_lr_decay):
        """Attach an ``lr_scale`` attribute to every parameter, decaying by
        ``layer_lr_decay`` per block from the last block backwards, and a
        ``param_name`` attribute holding the parameter's qualified name."""
        decay_rate = layer_lr_decay

        # layers -> blocks (depth)
        depth = sum(self.depths)
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

        def _set_lr_scale(m, scale):
            for p in m.parameters():
                p.lr_scale = scale

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                # The lambda is consumed immediately, so capturing the loop
                # variable is safe here.
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))  # noqa
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(
                    lambda x: _set_lr_scale(x, lr_scales[i - 1]))  # noqa
        assert i == depth
        for m in [self.norm_head, self.head]:
            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))

        for k, p in self.named_parameters():
            p.param_name = k

        def _check_lr_scale(m):
            for p in m.parameters():
                assert hasattr(p, 'lr_scale'), p.param_name

        self.apply(_check_lr_scale)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zero biases, and unit
        weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        """Encode ``x`` (N, C, H, W) and return the neck output in
        channels-first layout."""
        x = self.patch_embed(x)

        for layer in self.layers:
            x = layer(x)

        B, _, C = x.size()
        # The last stage emits H * W tokens for its own input resolution, so
        # read the grid from there instead of assuming 64 x 64 (which only
        # holds for 1024-pixel inputs).
        H, W = self.layers[-1].input_resolution
        x = x.view(B, H, W, C)
        x = x.permute(0, 3, 1, 2)
        x = self.neck(x)
        return x

    def forward(self, x):
        # Only features are returned; norm_head / head are intentionally
        # left out of the forward pass.
        x = self.forward_features(x)
        return x
url = _checkpoint_url_format.format( _provided_checkpoints[model_name]) checkpoint = torch.hub.load_state_dict_from_url( url=url, map_location='cpu', check_hash=False, ) model.load_state_dict(checkpoint['model']) return model # rename the name of fn_wrapper fn_wrapper.__name__ = fn.__name__ return register_model(fn_wrapper) @register_tiny_vit_model def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): return TinyViT( num_classes=num_classes, embed_dims=[64, 128, 160, 320], depths=[2, 2, 6, 2], num_heads=[2, 4, 5, 10], window_sizes=[7, 7, 14, 7], drop_path_rate=drop_path_rate, ) @register_tiny_vit_model def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): return TinyViT( num_classes=num_classes, embed_dims=[64, 128, 256, 448], depths=[2, 2, 6, 2], num_heads=[2, 4, 8, 14], window_sizes=[7, 7, 14, 7], drop_path_rate=drop_path_rate, ) @register_tiny_vit_model def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): return TinyViT( num_classes=num_classes, embed_dims=[96, 192, 384, 576], depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 18], window_sizes=[7, 7, 14, 7], drop_path_rate=drop_path_rate, ) @register_tiny_vit_model def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): return TinyViT( img_size=384, num_classes=num_classes, embed_dims=[96, 192, 384, 576], depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 18], window_sizes=[12, 12, 24, 12], drop_path_rate=drop_path_rate, ) @register_tiny_vit_model def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): return TinyViT( img_size=512, num_classes=num_classes, embed_dims=[96, 192, 384, 576], depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 18], window_sizes=[16, 16, 32, 16], drop_path_rate=drop_path_rate, ) ================================================ FILE: modules/control/proc/segment_anything/modeling/transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. 
# All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from torch import Tensor, nn import math from typing import Tuple, Type from .common import MLPBlock class TwoWayTransformer(nn.Module): def __init__( self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, ) -> None: """ A transformer decoder that attends to an input image using queries whose positional embedding is supplied. Args: depth (int): number of layers in the transformer embedding_dim (int): the channel dimension for the input embeddings num_heads (int): the number of heads for multihead attention. Must divide embedding_dim mlp_dim (int): the channel dimension internal to the MLP block activation (nn.Module): the activation to use in the MLP block """ super().__init__() self.depth = depth self.embedding_dim = embedding_dim self.num_heads = num_heads self.mlp_dim = mlp_dim self.layers = nn.ModuleList() for i in range(depth): self.layers.append( TwoWayAttentionBlock( embedding_dim=embedding_dim, num_heads=num_heads, mlp_dim=mlp_dim, activation=activation, attention_downsample_rate=attention_downsample_rate, skip_first_layer_pe=(i == 0), ) ) self.final_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm_final_attn = nn.LayerNorm(embedding_dim) def forward( self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ Args: image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding. point_embedding (torch.Tensor): the embedding to add to the query points. Must have shape B x N_points x embedding_dim for any N_points. 
Returns: torch.Tensor: the processed point_embedding torch.Tensor: the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C bs, c, h, w = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) # Prepare queries queries = point_embedding keys = image_embedding # Apply transformer blocks and final layernorm for layer in self.layers: queries, keys = layer( queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe, ) # Apply the final attention layer from the points to the image q = queries + point_embedding k = keys + image_pe attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm_final_attn(queries) return queries, keys class TwoWayAttentionBlock(nn.Module): def __init__( self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, ) -> None: """ A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse inputs. 
Arguments: embedding_dim (int): the channel dimension of the embeddings num_heads (int): the number of heads in the attention layers mlp_dim (int): the hidden dimension of the mlp block activation (nn.Module): the activation of the mlp block skip_first_layer_pe (bool): skip the PE on the first layer """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) self.norm1 = nn.LayerNorm(embedding_dim) self.cross_attn_token_to_image = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.norm2 = nn.LayerNorm(embedding_dim) self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) self.norm3 = nn.LayerNorm(embedding_dim) self.norm4 = nn.LayerNorm(embedding_dim) self.cross_attn_image_to_token = Attention( embedding_dim, num_heads, downsample_rate=attention_downsample_rate ) self.skip_first_layer_pe = skip_first_layer_pe def forward( self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor ) -> Tuple[Tensor, Tensor]: # Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries) # Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries) # MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries) # Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys) return queries, keys class Attention(nn.Module): """ An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and values. 
""" def __init__( self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, ) -> None: super().__init__() self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) self.v_proj = nn.Linear(embedding_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head def _recombine_heads(self, x: Tensor) -> Tensor: b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: # Input projections q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) # Separate into heads q = self._separate_heads(q, self.num_heads) k = self._separate_heads(k, self.num_heads) v = self._separate_heads(v, self.num_heads) # Attention _, _, _, c_per_head = q.shape attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens attn = attn / math.sqrt(c_per_head) attn = torch.softmax(attn, dim=-1) # Get output out = attn @ v out = self._recombine_heads(out) out = self.out_proj(out) return out ================================================ FILE: modules/control/proc/segment_anything/predictor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from typing import Optional, Tuple import numpy as np import torch from .modeling import Sam from .utils.transforms import ResizeLongestSide class SamPredictor: def __init__( self, sam_model: Sam, ) -> None: """ Uses SAM to calculate the image embedding for an image, and then allow repeated, efficient mask prediction given prompts. Arguments: sam_model (Sam): The model to use for mask prediction. """ super().__init__() self.model = sam_model self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) self.reset_image() def set_image( self, image: np.ndarray, image_format: str = "RGB", ) -> None: """ Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Arguments: image (np.ndarray): The image for calculating masks. Expects an image in HWC uint8 format, with pixel values in [0, 255]. image_format (str): The color format of the image, in ['RGB', 'BGR']. """ assert image_format in [ "RGB", "BGR", ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." if image_format != self.model.image_format: image = image[..., ::-1] # Transform the image to the form expected by the model input_image = self.transform.apply_image(image) input_image_torch = torch.as_tensor(input_image, device=self.device) input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] self.set_torch_image(input_image_torch, image.shape[:2]) def set_torch_image( self, transformed_image: torch.Tensor, original_image_size: Tuple[int, ...], ) -> None: """ Calculates the image embeddings for the provided image, allowing masks to be predicted with the 'predict' method. Expects the input image to be already transformed to the format expected by the model. Arguments: transformed_image (torch.Tensor): The input image, with shape 1x3xHxW, which has been transformed with ResizeLongestSide. original_image_size (tuple(int, int)): The size of the image before transformation, in (H, W) format. 
""" assert ( len(transformed_image.shape) == 4 and transformed_image.shape[1] == 3 and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." self.reset_image() self.original_size = original_image_size self.input_size = tuple(transformed_image.shape[-2:]) input_image = self.model.preprocess(transformed_image) self.features = self.model.image_encoder(input_image) self.is_image_set = True def predict( self, point_coords: Optional[np.ndarray] = None, point_labels: Optional[np.ndarray] = None, box: Optional[np.ndarray] = None, mask_input: Optional[np.ndarray] = None, multimask_output: bool = True, return_logits: bool = False, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Predict masks for the given input prompts, using the currently set image. Arguments: point_coords (np.ndarray or None): A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (np.ndarray or None): A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. 
Returns: (np.ndarray): The output masks in CxHxW format, where C is the number of masks, and (H, W) is the original image size. (np.ndarray): An array of length C containing the model's predictions for the quality of each mask. (np.ndarray): An array of shape CxHxW, where C is the number of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input. """ if not self.is_image_set: raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") # Transform input prompts coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None if point_coords is not None: assert ( point_labels is not None ), "point_labels must be supplied if point_coords is supplied." point_coords = self.transform.apply_coords(point_coords, self.original_size) coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] if box is not None: box = self.transform.apply_boxes(box, self.original_size) box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) box_torch = box_torch[None, :] if mask_input is not None: mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) mask_input_torch = mask_input_torch[None, :, :, :] masks, iou_predictions, low_res_masks = self.predict_torch( coords_torch, labels_torch, box_torch, mask_input_torch, multimask_output, return_logits=return_logits, ) masks_np = masks[0].detach().cpu().numpy() iou_predictions_np = iou_predictions[0].detach().cpu().numpy() low_res_masks_np = low_res_masks[0].detach().cpu().numpy() return masks_np, iou_predictions_np, low_res_masks_np def predict_torch( self, point_coords: Optional[torch.Tensor], point_labels: Optional[torch.Tensor], boxes: Optional[torch.Tensor] = None, mask_input: Optional[torch.Tensor] = None, multimask_output: 
bool = True, return_logits: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. Input prompts are batched torch tensors and are expected to already be transformed to the input frame using ResizeLongestSide. Arguments: point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the model. Each point is in (X,Y) in pixels. point_labels (torch.Tensor or None): A BxN array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. boxes (np.ndarray or None): A Bx4 array given a box prompt to the model, in XYXY format. mask_input (np.ndarray): A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous iteration of the predict method do not need further transformation. multimask_output (bool): If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. return_logits (bool): If true, returns un-thresholded masks logits instead of a binary mask. Returns: (torch.Tensor): The output masks in BxCxHxW format, where C is the number of masks, and (H, W) is the original image size. (torch.Tensor): An array of shape BxC containing the model's predictions for the quality of each mask. (torch.Tensor): An array of shape BxCxHxW, where C is the number of masks and H=W=256. These low res logits can be passed to a subsequent iteration as mask input. """ if not self.is_image_set: raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") if point_coords is not None: points = (point_coords, point_labels) else: points = None # Embed prompts sparse_embeddings, dense_embeddings = self.model.prompt_encoder( points=points, boxes=boxes, masks=mask_input, ) # Predict masks low_res_masks, iou_predictions = self.model.mask_decoder( image_embeddings=self.features, image_pe=self.model.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, ) # Upscale the masks to the original image resolution masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) if not return_logits: masks = masks > self.model.mask_threshold return masks, iou_predictions, low_res_masks def get_image_embedding(self) -> torch.Tensor: """ Returns the image embeddings for the currently set image, with shape 1xCxHxW, where C is the embedding dimension and (H,W) are the embedding spatial dimension of SAM (typically C=256, H=W=64). """ if not self.is_image_set: raise RuntimeError( "An image must be set with .set_image(...) to generate an embedding." ) assert self.features is not None, "Features must exist if an image has been set." return self.features @property def device(self) -> torch.device: return self.model.device def reset_image(self) -> None: """Resets the currently set image.""" self.is_image_set = False self.features = None self.orig_h = None self.orig_w = None self.input_h = None self.input_w = None ================================================ FILE: modules/control/proc/segment_anything/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
================================================
FILE: modules/control/proc/segment_anything/utils/amg.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple


class MaskData:
    """
    A structure for storing masks and their related data in batched format.
    Implements basic filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        """Index every stored value with `keep` (a boolean mask or an index tensor), in place."""
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, torch.Tensor):
                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[keep.detach().cpu().numpy()]
            elif isinstance(v, list) and keep.dtype == torch.bool:
                # Boolean mask over a python list: keep the flagged positions.
                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
            elif isinstance(v, list):
                # Index tensor over a python list: gather in the given order.
                self._stats[k] = [v[i] for i in keep]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def cat(self, new_stats: "MaskData") -> None:
        """Append the entries of `new_stats` to this object along the first axis, key by key."""
        for k, v in new_stats.items():
            if k not in self._stats or self._stats[k] is None:
                self._stats[k] = deepcopy(v)
            elif isinstance(v, torch.Tensor):
                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
            elif isinstance(v, np.ndarray):
                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, list):
                self._stats[k] = self._stats[k] + deepcopy(v)
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def to_numpy(self) -> None:
        """Convert every stored torch tensor to a numpy array, in place."""
        for k, v in self._stats.items():
            if isinstance(v, torch.Tensor):
                self._stats[k] = v.detach().cpu().numpy()


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    # Move the boxes from crop coordinates into original-image coordinates.
    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
    """Return a copy of a single box with entries 2 and 3 converted from corner to width/height."""
    box_xywh = deepcopy(box_xyxy)
    box_xywh[2] = box_xywh[2] - box_xywh[0]
    box_xywh[3] = box_xywh[3] - box_xywh[1]
    return box_xywh


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """Yield lists holding the same `batch_size`-long slice of every argument; the last batch may be shorter."""
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    # Ceiling division: one extra batch when the length is not a multiple.
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycoco tools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = tensor[:, 1:] ^ tensor[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = torch.cat(
            [
                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
                cur_idxs + 1,
                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        # Runs alternate starting with a zero-run, so a mask that begins
        # with a set pixel gets an explicit empty zero-run first.
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.detach().cpu().tolist())
        out.append({"size": [h, w], "counts": counts})
    return out


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order


def area_from_rle(rle: Dict[str, Any]) -> int:
    """Sum the odd-indexed run lengths, i.e. the runs that rle_to_mask fills with True."""
    return sum(rle["counts"][1::2])


def calculate_stability_score(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks. The stability
    score is the IoU between the binary masks obtained by thresholding
    the predicted mask logits at high and low values.
    """
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    # Half a cell of margin so the points sit at cell centres.
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points


def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """Generates point grids for all crop layers."""
    points_by_layer = []
    # Layer 0 is the full image, hence n_layers + 1 grids.
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        points_by_layer.append(build_point_grid(n_points))
    return points_by_layer


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer.
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format (x0, y0, x1, y1), clipped to the image bounds
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift XYXY boxes by the crop's top-left corner (x0, y0)."""
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift (x, y) points by the crop's top-left corner (x0, y0)."""
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    return points + offset


def uncrop_masks(
    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
    """Zero-pad crop-sized masks to (orig_h, orig_w); returned unchanged when the crop is the whole image."""
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    # Padding order for the last two dimensions: (left, right, top, bottom).
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # In "holes" mode the mask is inverted so that holes become the
    # connected components being measured.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    """Convert an uncompressed RLE dict to pycocotools' RLE, with `counts` decoded to a str."""
    from pycocotools import mask as mask_utils  # type: ignore

    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    if len(shape) > 2:
        masks = masks.flatten(0, -3)
    else:
        masks = masks.unsqueeze(0)

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    # Push rows without any mask pixel past the image so min() ignores them.
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    if len(shape) > 2:
        out = out.reshape(*shape[:-2], 4)
    else:
        out = out[0]

    return out

================================================
FILE: modules/control/proc/segment_anything/utils/onnx.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export
    script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """Scale a size tensor so its largest entry equals `longest_side`, rounding half up to int64."""
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """Embed point prompts using masked sums instead of branches, selecting the embedding by label."""
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Label -1 replaces the positional encoding with `not_a_point_embed`.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        """Blend the downscaled mask embedding with `no_mask_embed`, weighted by `has_mask_input`."""
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        """Upscale masks to img_size, crop to the resized image extent, then resize to `orig_im_size`."""
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick one mask per batch item by the highest reweighted score."""
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """
        Embed the prompts, decode masks and upscale them to `orig_im_size`.

        Returns (upscaled_masks, scores, masks), or (upscaled_masks, scores,
        stability_scores, areas, masks) when `return_extra_metrics` is set.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks

================================================
FILE: modules/control/proc/segment_anything/utils/transforms.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
""" old_h, old_w = original_size new_h, new_w = self.get_preprocess_shape( original_size[0], original_size[1], self.target_length ) coords = deepcopy(coords).astype(float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: """ Expects a numpy array shape Bx4. Requires the original image size in (H, W) format. """ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) return boxes.reshape(-1, 4) def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: """ Expects batched images with shape BxCxHxW and float format. This transformation may not exactly match apply_image. apply_image is the transformation expected by the model. """ # Expects an image in BCHW format. May not exactly match apply_image. target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) return F.interpolate( image, target_size, mode="bilinear", align_corners=False, antialias=True ) def apply_coords_torch( self, coords: torch.Tensor, original_size: Tuple[int, ...] ) -> torch.Tensor: """ Expects a torch tensor with length 2 in the last dimension. Requires the original image size in (H, W) format. """ old_h, old_w = original_size new_h, new_w = self.get_preprocess_shape( original_size[0], original_size[1], self.target_length ) coords = deepcopy(coords).to(torch.float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords def apply_boxes_torch( self, boxes: torch.Tensor, original_size: Tuple[int, ...] ) -> torch.Tensor: """ Expects a torch tensor with shape Bx4. Requires the original image size in (H, W) format. 
""" boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) return boxes.reshape(-1, 4) @staticmethod def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: """ Compute the output size given input size and target long side length. """ scale = long_side_length * 1.0 / max(oldh, oldw) newh, neww = oldh * scale, oldw * scale neww = int(neww + 0.5) newh = int(newh + 0.5) return (newh, neww) ================================================ FILE: modules/control/proc/shuffle.py ================================================ import warnings import random import cv2 import numpy as np from PIL import Image from modules.control.util import HWC3, img2mask, make_noise_disk, resize_image class ContentShuffleDetector: def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs): if "return_pil" in kwargs: warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) output_type = "pil" if kwargs["return_pil"] else "np" if type(output_type) is bool: warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") if output_type: output_type = "pil" if not isinstance(input_image, np.ndarray): input_image = np.array(input_image, dtype=np.uint8) input_image = HWC3(input_image) input_image = resize_image(input_image, detect_resolution) H, W, _C = input_image.shape if h is None: h = H if w is None: w = W if f is None: f = 256 x = make_noise_disk(h, w, 1, f) * float(W - 1) y = make_noise_disk(h, w, 1, f) * float(H - 1) flow = np.concatenate([x, y], axis=2).astype(np.float32) detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR) img = resize_image(input_image, image_resolution) H, W, _C = img.shape detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR) if output_type == "pil": detected_map = Image.fromarray(detected_map) return detected_map class 
ColorShuffleDetector: def __call__(self, img): H, W, C = img.shape F = np.random.randint(64, 384) # noqa A = make_noise_disk(H, W, 3, F) B = make_noise_disk(H, W, 3, F) C = (A + B) / 2.0 A = (C + (A - C) * 3.0).clip(0, 1) B = (C + (B - C) * 3.0).clip(0, 1) L = img.astype(np.float32) / 255.0 Y = A * L + B * (1 - L) Y -= np.min(Y, axis=(0, 1), keepdims=True) Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5) Y *= 255.0 return Y.clip(0, 255).astype(np.uint8) class GrayDetector: def __call__(self, img): eps = 1e-5 X = img.astype(np.float32) r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2] kr, kg, kb = [random.random() + eps for _ in range(3)] ks = kr + kg + kb kr /= ks kg /= ks kb /= ks Y = r * kr + g * kg + b * kb Y = np.stack([Y] * 3, axis=2) return Y.clip(0, 255).astype(np.uint8) class DownSampleDetector: def __call__(self, img, level=3, k=16.0): h = img.astype(np.float32) for _ in range(level): h += np.random.normal(loc=0.0, scale=k, size=h.shape) # noqa h = cv2.pyrDown(h) for _ in range(level): h = cv2.pyrUp(h) h += np.random.normal(loc=0.0, scale=k, size=h.shape) # noqa return h.clip(0, 255).astype(np.uint8) class Image2MaskShuffleDetector: def __init__(self, resolution=(640, 512)): self.H, self.W = resolution def __call__(self, img): m = img2mask(img, self.H, self.W) m *= 255.0 return m.clip(0, 255).astype(np.uint8) ================================================ FILE: modules/control/proc/zoe/LICENSE ================================================ MIT License Copyright (c) 2022 Intelligent Systems Lab Org Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright 
class ZoeDetector:
    """Depth-map preprocessor wrapping a ZoeDepth / ZoeDepthNK model."""

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _unwrap_state_dict(checkpoint):
        """
        Return the state dict stored in `checkpoint`.

        Checkpoints saved as {'model': state_dict, ...} are unwrapped; anything
        else is returned unchanged. (The previous `hasattr(dict, 'model')` test
        was always False, so wrapped checkpoints were never unwrapped and
        `load_state_dict(..., strict=False)` then silently loaded nothing.)
        """
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
            return checkpoint['model']
        return checkpoint

    @staticmethod
    def _depth_to_uint8(depth, gamma_corrected=False):
        """
        Convert a raw depth array to an inverted uint8 image.

        The 2nd and 85th percentiles are mapped to 1.0 and 0.0 respectively
        (larger depth values become darker), values outside are clamped, an
        optional gamma of 2.2 is applied and the result is scaled to 0..255.
        The input array is not modified.
        """
        depth = depth.astype(np.float32)  # astype copies, caller's array stays intact
        vmin = np.percentile(depth, 2)
        vmax = np.percentile(depth, 85)
        # Guard against a flat depth map, where vmax == vmin would divide by zero.
        span = max(float(vmax - vmin), 1e-6)
        depth -= vmin
        depth /= span
        depth = 1.0 - depth
        # Clamp before the gamma step: np.power on negative values yields NaN.
        depth = depth.clip(0.0, 1.0)
        if gamma_corrected:
            depth = np.power(depth, 2.2)
        return (depth * 255.0).clip(0, 255).astype(np.uint8)

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, model_type="zoedepth", filename=None, cache_dir=None, local_files_only=False):
        """
        Build a detector from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_or_path: local directory containing `filename`, or
                a hub repo id passed to `hf_hub_download`.
            model_type: "zoedepth" or "zoedepth_nk".
            filename: checkpoint file name; defaults to "ZoeD_M12_N.pt".
                Files ending in ".safetensors" are read with safetensors,
                everything else with `torch.load`.
            cache_dir, local_files_only: forwarded to `hf_hub_download`.

        Raises:
            ValueError: if `model_type` is not one of the supported values.
        """
        filename = filename or "ZoeD_M12_N.pt"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)

        if model_type == "zoedepth":
            model_cls = ZoeDepth
        elif model_type == "zoedepth_nk":
            model_cls = ZoeDepthNK
        else:
            raise ValueError(f"ZoeDepth unknown model type {model_type}")
        conf = get_config(model_type, "infer")
        model = model_cls.build_from_config(conf)

        if model_path.lower().endswith('.safetensors'):
            # `import safetensors` alone does not guarantee the torch submodule is loaded.
            import safetensors.torch
            model_dict = safetensors.torch.load_file(model_path, device='cpu')
        else:
            # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
            model_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model_dict = cls._unwrap_state_dict(model_dict)
        model.load_state_dict(model_dict, strict=False)

        # timm compatibility issue
        for b in model.core.core.pretrained.model.blocks:
            b.drop_path = torch.nn.Identity()
        model.eval()
        return cls(model)

    def to(self, device):
        """Move the wrapped model to `device` and return self."""
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type=None, gamma_corrected=False):
        """
        Estimate a depth map for `input_image`.

        Args:
            input_image: numpy array or anything `np.array` can convert
                (e.g. a PIL image).
            detect_resolution: resolution passed to `resize_image` before inference.
            image_resolution: resolution used to size the returned map.
            output_type: "pil" or "np"; when None it defaults to "np" for numpy
                input and "pil" otherwise.
            gamma_corrected: apply a 2.2 gamma to the normalised depth.

        Returns:
            The depth map as a PIL image or a uint8 numpy array.
        """
        self.model.to(devices.device)
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        image_depth = torch.from_numpy(input_image).float().to(device)
        image_depth = image_depth / 255.0
        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
        depth = self.model.infer(image_depth)
        if opts.control_move_processor:
            self.model.to('cpu')

        depth = depth[0, 0].cpu().numpy()
        depth_image = self._depth_to_uint8(depth, gamma_corrected)

        detected_map = HWC3(depth_image)

        # Match the size the input would have at image_resolution.
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
================================================ FILE: modules/control/proc/zoe/zoedepth/models/__init__.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# File author: Shariq Farooq Bhat ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/__init__.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
def denormalize(x):
    """Reverses the imagenet normalization applied to the input.

    Args:
        x (torch.Tensor - shape(N,3,H,W)): input tensor

    Returns:
        torch.Tensor - shape(N,3,H,W): Denormalized input
    """
    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
    return x * std + mean


def get_activation(name, bank):
    """Return a forward hook that stores a module's output in `bank[name]`."""
    def hook(_module, _inputs, output):
        bank[name] = output
    return hook


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                Accepted for interface compatibility; not used by this class.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round `x` to a multiple of `ensure_multiple_of`.

        Rounds to nearest first; falls back to rounding down when the result
        exceeds `max_val`, and to rounding up when it is below `min_val`.
        """
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output size for an input of the given width and height.

        Returns:
            tuple: (new_width, new_height)

        Raises:
            ValueError: if the configured resize_method is unknown.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(
                f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, x):
        """Bilinearly resize `x` (last two dims are H, W) to the size from `get_size`."""
        # shape[-2:] is (H, W); get_size takes and returns (width, height).
        width, height = self.get_size(*x.shape[-2:][::-1])
        # interpolate expects size=(out_H, out_W); passing (width, height)
        # would transpose the target size for non-square inputs.
        return nn.functional.interpolate(x, (int(height), int(width)), mode='bilinear', align_corners=True)


class PrepForMidas(object):
    """Input preparation for MiDaS: optional resize followed by normalization."""

    def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True):
        """
        Args:
            resize_mode (str): resize_method handed to `Resize`.
            keep_aspect_ratio (bool): handed to `Resize`.
            img_size (int or (H, W)): network input size; an int is used for both sides.
            do_resize (bool): when False the resize step is an identity.
        """
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        net_h, net_w = img_size
        self.normalization = Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \
            if do_resize else nn.Identity()

    def __call__(self, x):
        return self.normalization(self.resizer(x))


class MidasCore(nn.Module):
    def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True,
                 img_size=384, **kwargs):
        """Midas Base model used for multi-scale feature extraction.

        Args:
            midas (torch.nn.Module): Midas model.
            trainable (bool, optional): Train midas model. Defaults to False.
            fetch_features (bool, optional): Extract multi-scale features. Defaults to True.
            layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1').
            freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False.
            keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True.
            img_size (int, tuple, optional): Input resolution. Defaults to 384.
            **kwargs: only 'do_resize' (default True) is read; it is forwarded to PrepForMidas.
        """
        super().__init__()
        self.core = midas
        self.output_channels = None
        self.core_out = {}  # hook outputs, keyed by layer name
        self.trainable = trainable
        self.fetch_features = fetch_features
        self.handles = []  # handles of the currently registered forward hooks
        self.layer_names = layer_names

        self.set_trainable(trainable)
        self.set_fetch_features(fetch_features)

        self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio,
                                 img_size=img_size, do_resize=kwargs.get('do_resize', True))

        if freeze_bn:
            self.freeze_bn()

    def set_trainable(self, trainable):
        """Freeze or unfreeze all parameters; returns self."""
        self.trainable = trainable
        if trainable:
            self.unfreeze()
        else:
            self.freeze()
        return self

    def set_fetch_features(self, fetch_features):
        """Enable or disable feature extraction by attaching/removing hooks; returns self."""
        self.fetch_features = fetch_features
        if fetch_features:
            if len(self.handles) == 0:
                self.attach_hooks(self.core)
        else:
            self.remove_hooks()
        return self

    def freeze(self):
        """Disable gradients for every parameter; returns self."""
        for p in self.parameters():
            p.requires_grad = False
        self.trainable = False
        return self

    def unfreeze(self):
        """Enable gradients for every parameter; returns self."""
        for p in self.parameters():
            p.requires_grad = True
        self.trainable = True
        return self

    def freeze_bn(self):
        """Put every BatchNorm2d submodule into eval mode; returns self."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self

    def forward(self, x, denorm=False, return_rel_depth=False):
        """Run the MiDaS core.

        Args:
            x: input batch.
            denorm (bool): undo imagenet normalization on `x` first.
            return_rel_depth (bool): also return the core's output when
                features are fetched.

        Returns:
            The core output when `fetch_features` is False; otherwise the list
            of hooked activations in `layer_names` order, preceded by the core
            output when `return_rel_depth` is True.
        """
        if denorm:
            x = denormalize(x)
        x = self.prep(x)

        # Gradients are only tracked when the core is trainable.
        with torch.set_grad_enabled(self.trainable):
            rel_depth = self.core(x)
            if not self.fetch_features:
                return rel_depth
        out = [self.core_out[k] for k in self.layer_names]

        if return_rel_depth:
            return rel_depth, out
        return out

    def get_rel_pos_params(self):
        """Yield encoder parameters whose name contains "relative_position"."""
        for name, p in self.core.pretrained.named_parameters():
            if "relative_position" in name:
                yield p

    def get_enc_params_except_rel_pos(self):
        """Yield encoder parameters whose name does not contain "relative_position"."""
        for name, p in self.core.pretrained.named_parameters():
            if "relative_position" not in name:
                yield p

    def freeze_encoder(self, freeze_rel_pos=False):
        """Freeze the encoder; relative-position parameters only when `freeze_rel_pos` is True."""
        if freeze_rel_pos:
            for p in self.core.pretrained.parameters():
                p.requires_grad = False
        else:
            for p in self.get_enc_params_except_rel_pos():
                p.requires_grad = False
        return self

    def attach_hooks(self, midas):
        """Register forward hooks on `midas` for every name in `layer_names`; returns self."""
        if len(self.handles) > 0:
            self.remove_hooks()
        if "out_conv" in self.layer_names:
            self.handles.append(list(midas.scratch.output_conv.children())[
                                3].register_forward_hook(get_activation("out_conv", self.core_out)))
        if "r4" in self.layer_names:
            self.handles.append(midas.scratch.refinenet4.register_forward_hook(
                get_activation("r4", self.core_out)))
        if "r3" in self.layer_names:
            self.handles.append(midas.scratch.refinenet3.register_forward_hook(
                get_activation("r3", self.core_out)))
        if "r2" in self.layer_names:
            self.handles.append(midas.scratch.refinenet2.register_forward_hook(
                get_activation("r2", self.core_out)))
        if "r1" in self.layer_names:
            self.handles.append(midas.scratch.refinenet1.register_forward_hook(
                get_activation("r1", self.core_out)))
        if "l4_rn" in self.layer_names:
            self.handles.append(midas.scratch.layer4_rn.register_forward_hook(
                get_activation("l4_rn", self.core_out)))

        return self

    def remove_hooks(self):
        """Remove every registered hook and forget its handle; returns self."""
        for h in self.handles:
            h.remove()
        # Drop the stale handles: set_fetch_features() only re-attaches hooks
        # when the list is empty, and attach_hooks() appends to it.
        self.handles.clear()
        return self

    def __del__(self):
        self.remove_hooks()

    def set_output_channels(self, model_type):
        """Look up the output channel tuple for `model_type` in MIDAS_SETTINGS."""
        self.output_channels = MIDAS_SETTINGS[model_type]

    @staticmethod
    def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs):
        """Create a MidasCore from the bundled local 'midas_repo' via torch.hub.

        Raises:
            ValueError: if `midas_model_type` is not a key of MIDAS_SETTINGS.
        """
        if midas_model_type not in MIDAS_SETTINGS:
            raise ValueError(
                f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}")
        if "img_size" in kwargs:
            kwargs = MidasCore.parse_img_size(kwargs)
        img_size = kwargs.pop("img_size", [384, 384])
        midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo')
        midas = torch.hub.load(midas_path, midas_model_type,
                               pretrained=use_pretrained_midas, force_reload=force_reload, source='local')
        kwargs.update({'keep_aspect_ratio': force_keep_ar})
        midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features,
                               freeze_bn=freeze_bn, img_size=img_size, **kwargs)
        midas_core.set_output_channels(midas_model_type)
        return midas_core

    @staticmethod
    def build_from_config(config):
        """Call `build` with the entries of `config` as keyword arguments."""
        return MidasCore.build(**config)

    @staticmethod
    def parse_img_size(config):
        """Normalise config['img_size'] in place to a two-element list and return `config`.

        Accepts a comma separated "H,W" string, a single int (used for both
        sides) or a list of length 2.
        """
        assert 'img_size' in config
        if isinstance(config['img_size'], str):
            assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W"
            config['img_size'] = list(map(int, config['img_size'].split(",")))
            assert len(
                config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W"
        elif isinstance(config['img_size'], int):
            config['img_size'] = [config['img_size'], config['img_size']]
        else:
            assert isinstance(config['img_size'], list) and len(
                config['img_size']) == 2, "img_size should be a list of H,W"
        return config


nchannels2models = {
    tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"],
    (512, 256, 128, 64, 64): ["MiDaS_small"]
}

# Model name to number of output channels
MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items()
                  for m in v
                  }
this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/README.md ================================================ ## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): >Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun and our [preprint](https://arxiv.org/abs/2103.13413): > Vision Transformers for Dense Prediction > René Ranftl, Alexey Bochkovskiy, Vladlen Koltun MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with multi-objective optimization. The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). 
The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. ![](figures/Improvement_vs_FPS.png) ### Setup 1) Pick one or more models and download the corresponding weights to the `weights` folder: MiDaS 3.1 - For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) - For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) - For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) - For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) 1) Set up dependencies: ```shell conda env create -f environment.yaml conda activate midas-py310 ``` #### optional For the Next-ViT model, execute ```shell git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit ``` For the OpenVINO model, install ```shell pip install openvino ``` ### Usage 1) Place one or more input images in the folder `input`. 
2) Run the model with ```shell python run.py --model_type <model_type> --input_path input --output_path output ``` where ```<model_type>``` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). 3) The resulting depth maps are written to the `output` folder. #### optional 1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single inference height but a range of different heights. Feel free to explore different heights by appending the extra command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may decrease the model accuracy. 2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, disregarding the aspect ratio while preserving the height, use the command line argument `--square`. #### via Camera If you want the input images to be grabbed from the camera and shown in a window, omit the input and output paths and choose a model type as shown above: ```shell python run.py --model_type <model_type> --side ``` The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown side-by-side for comparison. 
#### via Docker 1) Make sure you have installed Docker and the [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). 2) Build the Docker image: ```shell docker build -t midas . ``` 3) Run inference: ```shell docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas ``` This command passes through all of your NVIDIA GPUs to the container, mounts the `input` and `output` directories and then runs the inference. #### via PyTorch Hub The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) #### via TensorFlow or ONNX See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. Currently only supports MiDaS v2.1. #### via Mobile (iOS / Android) See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. #### via ROS1 (Robot Operating System) See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. Currently only supports MiDaS v2.1. DPT-based models to be added. ### Accuracy We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets (see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. $\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by the numbers in the model names. The table also shows the **number of parameters** (in millions) and the **frames per second** for inference at the training resolution (for GPU RTX 3090): | MiDaS Model | DIW
WHDR | Eth3d
AbsRel | Sintel
AbsRel | TUM
δ1 | KITTI
δ1 | NYUv2
δ1 | $\color{green}{\textsf{Imp.}}$
% | Par.
M | FPS
  | |-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| | **Inference height 512** | | | | | | | | | | | [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | | [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | | | | | | | | | | | | | **Inference height 384** | | | | | | | | | | | [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | | [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | | [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | | [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | | [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | | [v3.1 
Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | | [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | | [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | | [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | | [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | | | | | | | | | | | | | **Inference height 256** | | | | | | | | | | | [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | | [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | | | | | | | | | | | | | **Inference height 224** | | | | | | | | | | | [v3.1 LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | * No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ $\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model does not support non-square resolutions 
(Swin, Swin2, LeViT) or for comparison with these models. All other validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the improvement, because these quantities are averages over the pixels of an image and do not take into account the advantage of more details due to a higher resolution.\ Best values per column and same validation height in bold #### Improvement The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 and v2.0 Large384 respectively instead of v3.0 DPTL-384. 
### Depth map comparison Zoom in for better visibility ![](figures/Comparison.png) ### Speed on Camera Feed Test configuration - Windows 10 - 11th Gen Intel Core i7-1185G7 3.00GHz - 16GB RAM - Camera resolution 640x480 - openvino_midas_v21_small_256 Speed: 22 FPS ### Changelog * [Dec 2022] Released MiDaS v3.1: - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 - Integrated live depth estimation from camera feed * [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). * [Apr 2021] Released MiDaS v3.0: - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 - Additional models can be found [here](https://github.com/isl-org/DPT) * [Nov 2020] Released MiDaS v2.1: - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. 
- Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots * [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). * [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust * [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) ### Citation Please cite our paper if you use this code or any of the models: ``` @ARTICLE {Ranftl2022, author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", year = "2022", volume = "44", number = "3" } ``` If you use a DPT-based model, please also cite: ``` @article{Ranftl2021, author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, title = {Vision Transformers for Dense Prediction}, journal = {ICCV}, year = {2021}, } ``` ### Acknowledgements Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). We'd like to thank the authors for making these libraries available. 
### License MIT License ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/__init__.py ================================================ ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/hubconf.py ================================================ dependencies = ["torch"] import torch from midas.dpt_depth import DPTDepthModel from midas.midas_net import MidasNet from midas.midas_net_custom import MidasNet_small def DPT_BEiT_L_512(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_BEiT_L_512 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="beitl16_512", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_BEiT_L_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_BEiT_L_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="beitl16_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_BEiT_B_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_BEiT_B_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="beitb16_384", non_negative=True, ) if pretrained: 
checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_SwinV2_L_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_SwinV2_L_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="swin2l24_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_SwinV2_B_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_SwinV2_B_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="swin2b24_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_SwinV2_T_256(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_SwinV2_T_256 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="swin2t16_256", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return 
model def DPT_Swin_L_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_Swin_L_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="swinl12_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_Next_ViT_L_384(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="next_vit_large_6m", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_LeViT_224(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT_LeViT_224 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="levit_384", non_negative=True, head_features_1=64, head_features_2=8, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_Large(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT-Large model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( 
path=None, backbone="vitl16_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def DPT_Hybrid(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS DPT-Hybrid model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = DPTDepthModel( path=None, backbone="vitb_rn50_384", non_negative=True, ) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def MiDaS(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS v2.1 model for monocular depth estimation pretrained (bool): load pretrained weights into model """ model = MidasNet() if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) model.load_state_dict(state_dict) return model def MiDaS_small(pretrained=True, **kwargs): """ # This docstring shows up in hub.help() MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices pretrained (bool): load pretrained weights into model """ model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) if pretrained: checkpoint = ( "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" ) state_dict = torch.hub.load_state_dict_from_url( checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True ) 
model.load_state_dict(state_dict) return model def transforms(): import cv2 from torchvision.transforms import Compose from midas.transforms import Resize, NormalizeImage, PrepareForNet from midas import transforms transforms.default_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 384, 384, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="upper_bound", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.small_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 256, 256, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="upper_bound", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.dpt_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 384, 384, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.beit512_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 512, 512, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.swin384_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 384, 384, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), 
NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.swin256_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 256, 256, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) transforms.levit_transform = Compose( [ lambda img: {"image": img / 255.0}, Resize( 224, 224, resize_target=None, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), ] ) return transforms ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/__init__.py ================================================ ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/__init__.py ================================================ ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py ================================================ import timm import torch import types import numpy as np import torch.nn.functional as F from .utils import forward_adapted_unflatten, make_backbone_default from timm.models.beit import gen_relative_position_index from torch.utils.checkpoint import checkpoint from typing import Optional def forward_beit(pretrained, x): return forward_adapted_unflatten(pretrained, x, "forward_features") def patch_embed_forward(self, x): """ Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support 
arbitrary window sizes. """ x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x def _get_rel_pos_bias(self, window_size): """ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. """ old_height = 2 * self.window_size[0] - 1 old_width = 2 * self.window_size[1] - 1 new_height = 2 * window_size[0] - 1 new_width = 2 * window_size[1] - 1 old_relative_position_bias_table = self.relative_position_bias_table old_num_relative_distance = self.num_relative_distance new_num_relative_distance = new_height * new_width + 3 old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear") new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) new_relative_position_bias_table = torch.cat( [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) key = str(window_size[1]) + "," + str(window_size[0]) if key not in self.relative_position_indices.keys(): self.relative_position_indices[key] = gen_relative_position_index(window_size) relative_position_bias = new_relative_position_bias_table[ self.relative_position_indices[key].view(-1)].view( window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww return relative_position_bias.unsqueeze(0) def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): """ Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. 
""" B, N, C = x.shape qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: window_size = tuple(np.array(resolution) // 16) attn = attn + self._get_rel_pos_bias(window_size) if shared_rel_pos_bias is not None: attn = attn + shared_rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): """ Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. """ if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x def beit_forward_features(self, x): """ Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. 
""" resolution = x.shape[2:] x = self.patch_embed(x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) else: x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) x = self.norm(x) return x def _make_beit_backbone( model, features=None, size=None, hooks=None, vit_features=768, use_readout="ignore", start_index=1, start_index_readout=1, ): if hooks is None: hooks = [0, 4, 8, 11] if size is None: size = [384, 384] if features is None: features = [96, 192, 384, 768] backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, start_index_readout) backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) for block in backbone.model.blocks: attn = block.attn attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) attn.forward = types.MethodType(attention_forward, attn) attn.relative_position_indices = {} block.forward = types.MethodType(block_forward, block) return backbone def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) hooks = [5, 11, 17, 23] if hooks is None else hooks features = [256, 512, 1024, 1024] return _make_beit_backbone( model, features=features, size=[512, 512], hooks=hooks, vit_features=1024, use_readout=use_readout, ) def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) hooks = [5, 11, 17, 23] if hooks is None else hooks return _make_beit_backbone( model, 
features=[256, 512, 1024, 1024], hooks=hooks, vit_features=1024, use_readout=use_readout, ) def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks is None else hooks return _make_beit_backbone( model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout, ) ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py ================================================ import timm import torch import torch.nn as nn import numpy as np from .utils import activations, get_activation, Transpose def forward_levit(pretrained, x): pretrained.model.forward_features(x) layer_1 = pretrained.activations["1"] layer_2 = pretrained.activations["2"] layer_3 = pretrained.activations["3"] layer_1 = pretrained.act_postprocess1(layer_1) layer_2 = pretrained.act_postprocess2(layer_2) layer_3 = pretrained.act_postprocess3(layer_3) return layer_1, layer_2, layer_3 def _make_levit_backbone( model, hooks=None, patch_grid=None ): if patch_grid is None: patch_grid = [14, 14] if hooks is None: hooks = [3, 11, 21] pretrained = nn.Module() pretrained.model = model pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.activations = activations patch_grid_size = np.array(patch_grid, dtype=int) pretrained.act_postprocess1 = nn.Sequential( Transpose(1, 2), nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) ) pretrained.act_postprocess2 = nn.Sequential( Transpose(1, 2), nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) ) pretrained.act_postprocess3 = nn.Sequential( Transpose(1, 2), nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) ) 
return pretrained class ConvTransposeNorm(nn.Sequential): """ Modification of https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm such that ConvTranspose2d is used instead of Conv2d. """ def __init__( self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): super().__init__() self.add_module('c', nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) self.add_module('bn', nn.BatchNorm2d(out_chs)) nn.init.constant_(self.bn.weight, bn_weight_init) def fuse(self): c, bn = self._modules.values() w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 m = nn.ConvTranspose2d( w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def stem_b4_transpose(in_chs, out_chs, activation): """ Modification of https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 
""" return nn.Sequential( ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), activation(), ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), activation()) def _make_pretrained_levit_384(pretrained, hooks=None): model = timm.create_model("levit_384", pretrained=pretrained) hooks = [3, 11, 21] if hooks is None else hooks return _make_levit_backbone( model, hooks=hooks ) ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py ================================================ import timm import torch.nn as nn from .utils import activations, forward_default, get_activation from ..external.next_vit.classification.nextvit import * # noqa def forward_next_vit(pretrained, x): return forward_default(pretrained, x, "forward") def _make_next_vit_backbone( model, hooks=None, ): if hooks is None: hooks = [2, 6, 36, 39] pretrained = nn.Module() pretrained.model = model pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) pretrained.activations = activations return pretrained def _make_pretrained_next_vit_large_6m(hooks=None): model = timm.create_model("nextvit_large") hooks = [2, 6, 36, 39] if hooks is None else hooks return _make_next_vit_backbone( model, hooks=hooks, ) ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py ================================================ import timm from .swin_common import _make_swin_backbone def _make_pretrained_swinl12_384(pretrained, hooks=None): model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) hooks = [1, 1, 17, 1] if hooks is None else hooks return _make_swin_backbone( model, 
def _make_pretrained_swin2t16_256(pretrained, hooks=None):
    """Create the ``swinv2_tiny_window16_256`` timm model wrapped as a hooked backbone.

    Args:
        pretrained: forwarded to ``timm.create_model`` as its ``pretrained`` argument.
        hooks: block indices (one per stage) to hook; defaults to ``[1, 1, 5, 1]``.

    Returns:
        The module returned by ``_make_swin_backbone`` for a 64 x 64 patch grid.
    """
    if hooks is None:
        hooks = [1, 1, 5, 1]

    swin_model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
    return _make_swin_backbone(swin_model, hooks=hooks, patch_grid=[64, 64])
class AddReadout(nn.Module):
    """Drop the leading token(s) along dim 1 and add a readout value to the rest."""

    def __init__(self, start_index=1):
        super().__init__()
        # Index along dim 1 of the first token that is kept.
        self.start_index = start_index

    def forward(self, x):
        """Return ``x[:, start_index:]`` with the readout added to every kept token.

        The readout is token 0, except when ``start_index == 2``, in which case
        it is the mean of tokens 0 and 1.
        """
        first_token = x[:, 0]
        if self.start_index == 2:
            readout = (first_token + x[:, 1]) / 2
        else:
            readout = first_token

        kept_tokens = x[:, self.start_index:]
        # unsqueeze(1) lets the readout broadcast over the token axis.
        return kept_tokens + readout.unsqueeze(1)
def forward_default(pretrained, x, function_name="forward_features"):
    """Run the backbone and return its four hooked, post-processed activations.

    The method named ``function_name`` is called on ``pretrained.model`` only
    for its side effect: the registered forward hooks store the activations in
    ``pretrained.activations`` under the keys ``"1"`` .. ``"4"``.

    Args:
        pretrained: wrapper exposing ``model``, ``activations`` and, optionally,
            ``act_postprocess1`` .. ``act_postprocess4``.
        x: input passed unchanged to the backbone method.
        function_name (str): name of the method to call on ``pretrained.model``.

    Returns:
        tuple: ``(layer_1, layer_2, layer_3, layer_4)``; a layer is returned
        as captured when the matching ``act_postprocessN`` attribute is absent.
    """
    # Look the method up by name instead of exec()-ing a formatted string:
    # same call for plain method names, without evaluating arbitrary text.
    getattr(pretrained.model, function_name)(x)

    captured = pretrained.activations
    layers = [captured["1"], captured["2"], captured["3"], captured["4"]]

    for position in range(4):
        attr = f"act_postprocess{position + 1}"
        if hasattr(pretrained, attr):
            layers[position] = getattr(pretrained, attr)(layers[position])

    return tuple(layers)
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout-handling module per entry of ``features``.

    Args:
        vit_features: feature width handed to ``ProjectReadout``.
        features: sequence whose length decides how many modules are returned.
        use_readout (str): ``"ignore"`` (``Slice``), ``"add"`` (``AddReadout``)
            or ``"project"`` (``ProjectReadout``).
        start_index (int): forwarded to the chosen module.

    Returns:
        list: the readout modules. For ``"ignore"`` and ``"add"`` the list holds
        the same module instance repeated; ``"project"`` creates a separate
        instance for every entry.

    Raises:
        AssertionError: if ``use_readout`` is none of the supported values.
    """
    if use_readout == "ignore":
        return [Slice(start_index)] * len(features)

    if use_readout == "add":
        return [AddReadout(start_index)] * len(features)

    if use_readout == "project":
        # One ProjectReadout per feature entry, each with its own parameters.
        return [ProjectReadout(vit_features, start_index) for _ in features]

    raise AssertionError("wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'")
stride=2, padding=0, bias=True, dilation=1, groups=1, ), ) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[2], kernel_size=1, stride=1, padding=0, ), ) pretrained.act_postprocess4 = nn.Sequential( readout_oper[3], Transpose(1, 2), nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), nn.Conv2d( in_channels=vit_features, out_channels=features[3], kernel_size=1, stride=1, padding=0, ), nn.Conv2d( in_channels=features[3], out_channels=features[3], kernel_size=3, stride=2, padding=1, ), ) pretrained.model.start_index = start_index pretrained.model.patch_size = [16, 16] return pretrained ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py ================================================ import torch import torch.nn as nn import timm import types import math import torch.nn.functional as F from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, make_backbone_default, Transpose) def forward_vit(pretrained, x): return forward_adapted_unflatten(pretrained, x, "forward_flex") def _resize_pos_embed(self, posemb, gs_h, gs_w): posemb_tok, posemb_grid = ( posemb[:, : self.start_index], posemb[0, self.start_index:], ) gs_old = int(math.sqrt(len(posemb_grid))) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def forward_flex(self, x): b, c, h, w = x.shape pos_embed = self._resize_pos_embed( self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] ) B = x.shape[0] if hasattr(self.patch_embed, "backbone"): x = self.patch_embed.backbone(x) if 
isinstance(x, (list, tuple)): x = x[-1] # last feature if backbone outputs list/tuple of features x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: cls_tokens = self.cls_token.expand( B, -1, -1 ) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: if self.no_embed_class: x = x + pos_embed cls_tokens = self.cls_token.expand( B, -1, -1 ) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if not self.no_embed_class: x = x + pos_embed x = self.pos_drop(x) for blk in self.blocks: x = blk(x) x = self.norm(x) return x def _make_vit_b16_backbone( model, features=None, size=None, hooks=None, vit_features=768, use_readout="ignore", start_index=1, start_index_readout=1, ): if hooks is None: hooks = [2, 5, 8, 11] if size is None: size = [384, 384] if features is None: features = [96, 192, 384, 768] pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, start_index_readout) # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
def _make_vit_b_rn50_backbone(
    model,
    features=None,
    size=None,
    hooks=None,
    vit_features=768,
    patch_size=None,
    number_stages=2,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    """Wrap a hybrid ResNet/ViT model with activation hooks and post-processing.

    Unless ``use_vit_only`` is set, the first ``number_stages`` activations are
    hooked on ``model.patch_embed.backbone.stages`` and get an identity
    post-process; the remaining ones (up to four in total) are hooked on
    ``model.blocks[hooks[s]]`` and get a readout / transpose / unflatten /
    1x1-conv post-process, optionally followed by a resampling layer.

    Args:
        model: the model to wrap; must expose ``patch_embed.backbone.stages``
            (when stages are used) and ``blocks``.
        features (list[int]): output channels per activation.
        size (list[int]): input size used to derive the unflatten grid.
        hooks (list[int]): transformer block index per activation.
        vit_features (int): transformer feature width.
        patch_size (list[int]): stored on the model as ``patch_size``.
        number_stages (int): number of convolutional stages to hook.
        use_vit_only (bool): hook transformer blocks for all four activations.
        use_readout (str): readout mode passed to ``get_readout_oper``.
        start_index (int): readout start index; also stored on the model.

    Returns:
        nn.Module: wrapper with ``model``, ``activations`` and
        ``act_postprocess1`` .. ``act_postprocess4``.
    """
    if patch_size is None:
        patch_size = [16, 16]
    if hooks is None:
        hooks = [0, 1, 8, 11]
    if size is None:
        size = [384, 384]
    if features is None:
        features = [256, 512, 768, 768]
    pretrained = nn.Module()

    pretrained.model = model

    used_number_stages = 0 if use_vit_only else number_stages
    for s in range(used_number_stages):
        pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
            get_activation(str(s + 1))
        )
    for s in range(used_number_stages, 4):
        pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # The post-process modules are attached with setattr. Previously the freshly
    # built nn.Sequential was discarded and an exec() referenced an undefined
    # name ``value``, which raised NameError.
    for s in range(used_number_stages):
        # Three identities keep the slicing done by the forward helper valid.
        setattr(
            pretrained,
            f"act_postprocess{s + 1}",
            nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()),
        )
    for s in range(used_number_stages, 4):
        if s < number_stages:
            # Upsample the early activations.
            final_layer = nn.ConvTranspose2d(
                in_channels=features[s],
                out_channels=features[s],
                kernel_size=4 // (2 ** s),
                stride=4 // (2 ** s),
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            )
        elif s > number_stages:
            # Downsample the last activation.
            final_layer = nn.Conv2d(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            )
        else:
            final_layer = None

        layers = [
            readout_oper[s],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[s],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        ]
        if final_layer is not None:
            layers.append(final_layer)

        setattr(pretrained, f"act_postprocess{s + 1}", nn.Sequential(*layers))

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = patch_size

    # We inject these functions into the VisionTransformer instance so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
Args: path (str): file path """ parameters = torch.load(path, map_location=torch.device('cpu')) if "optimizer" in parameters: parameters = parameters["model"] self.load_state_dict(parameters) ================================================ FILE: modules/control/proc/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py ================================================ import torch import torch.nn as nn from .backbones.beit import ( _make_pretrained_beitl16_512, _make_pretrained_beitl16_384, _make_pretrained_beitb16_384, forward_beit, ) from .backbones.swin_common import ( forward_swin, ) from .backbones.swin2 import ( _make_pretrained_swin2l24_384, _make_pretrained_swin2b24_384, _make_pretrained_swin2t16_256, ) from .backbones.swin import ( _make_pretrained_swinl12_384, ) from .backbones.levit import ( _make_pretrained_levit_384, forward_levit, ) from .backbones.vit import ( _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384, _make_pretrained_vitb16_384, forward_vit, ) def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore", in_features=None): if in_features is None: in_features = [96, 256, 512, 1024] if backbone == "beitl16_512": pretrained = _make_pretrained_beitl16_512( use_pretrained, hooks=hooks, use_readout=use_readout ) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # BEiT_512-L (backbone) elif backbone == "beitl16_384": pretrained = _make_pretrained_beitl16_384( use_pretrained, hooks=hooks, use_readout=use_readout ) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # BEiT_384-L (backbone) elif backbone == "beitb16_384": pretrained = _make_pretrained_beitb16_384( use_pretrained, hooks=hooks, use_readout=use_readout ) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # BEiT_384-B (backbone) elif backbone == "swin2l24_384": pretrained = 
_make_pretrained_swin2l24_384( use_pretrained, hooks=hooks ) scratch = _make_scratch( [192, 384, 768, 1536], features, groups=groups, expand=expand ) # Swin2-L/12to24 (backbone) elif backbone == "swin2b24_384": pretrained = _make_pretrained_swin2b24_384( use_pretrained, hooks=hooks ) scratch = _make_scratch( [128, 256, 512, 1024], features, groups=groups, expand=expand ) # Swin2-B/12to24 (backbone) elif backbone == "swin2t16_256": pretrained = _make_pretrained_swin2t16_256( use_pretrained, hooks=hooks ) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # Swin2-T/16 (backbone) elif backbone == "swinl12_384": pretrained = _make_pretrained_swinl12_384( use_pretrained, hooks=hooks ) scratch = _make_scratch( [192, 384, 768, 1536], features, groups=groups, expand=expand ) # Swin-L/12 (backbone) elif backbone == "next_vit_large_6m": from .backbones.next_vit import _make_pretrained_next_vit_large_6m pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) scratch = _make_scratch( in_features, features, groups=groups, expand=expand ) # Next-ViT-L on ImageNet-1K-6M (backbone) elif backbone == "levit_384": pretrained = _make_pretrained_levit_384( use_pretrained, hooks=hooks ) scratch = _make_scratch( [384, 512, 768], features, groups=groups, expand=expand ) # LeViT 384 (backbone) elif backbone == "vitl16_384": pretrained = _make_pretrained_vitl16_384( use_pretrained, hooks=hooks, use_readout=use_readout ) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) elif backbone == "vitb_rn50_384": pretrained = _make_pretrained_vitb_rn50_384( use_pretrained, hooks=hooks, use_vit_only=use_vit_only, use_readout=use_readout, ) scratch = _make_scratch( [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": pretrained = _make_pretrained_vitb16_384( use_pretrained, hooks=hooks, use_readout=use_readout ) 
scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") raise AssertionError return pretrained, scratch def _make_scratch(in_shape, out_shape, groups=1, expand=False): scratch = nn.Module() out_shape1 = out_shape out_shape2 = out_shape out_shape3 = out_shape if len(in_shape) >= 4: out_shape4 = out_shape if expand: out_shape1 = out_shape out_shape2 = out_shape*2 out_shape3 = out_shape*4 if len(in_shape) >= 4: out_shape4 = out_shape*8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer2_rn = nn.Conv2d( in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) scratch.layer3_rn = nn.Conv2d( in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) if len(in_shape) >= 4: scratch.layer4_rn = nn.Conv2d( in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups ) return scratch def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): efficientnet = torch.hub.load( "rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable ) return _make_efficientnet_backbone(efficientnet) def _make_efficientnet_backbone(effnet): pretrained = nn.Module() pretrained.layer1 = nn.Sequential( effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] ) pretrained.layer2 = 
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` with fixed settings."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Store the interpolation settings.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded to the interpolation call
        """
        super().__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Interpolate ``x`` with the stored settings.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        settings = {
            "scale_factor": self.scale_factor,
            "mode": self.mode,
            "align_corners": self.align_corners,
        }
        return self.interp(x, **settings)
class ResidualConvUnit_custom(nn.Module):
    """Residual block: two activation + 3x3 conv steps (optionally batch-normed) plus a skip."""

    def __init__(self, features, activation, bn):
        """Create the two convolutions and, if requested, their batch norms.

        Args:
            features (int): number of input and output channels
            activation: callable applied before each convolution
            bn: batch norm layers are created and used only when this is exactly ``True``
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        conv_settings = {
            "kernel_size": 3,
            "stride": 1,
            "padding": 1,
            "bias": True,
            "groups": self.groups,
        }
        self.conv1 = nn.Conv2d(features, features, **conv_settings)
        self.conv2 = nn.Conv2d(features, features, **conv_settings)

        if self.bn is True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # Functional wrapper used for the residual addition.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply the residual unit.

        Args:
            x (tensor): input

        Returns:
            tensor: ``x`` plus the output of the two conv steps
        """
        with_bn = self.bn is True

        hidden = self.conv1(self.activation(x))
        if with_bn:
            hidden = self.bn1(hidden)

        hidden = self.conv2(self.activation(hidden))
        if with_bn:
            hidden = self.bn2(hidden)

        # NOTE(review): ``groups`` is fixed to 1 in __init__ and no ``conv_merge``
        # is created in this class, so this branch is not taken as written.
        if self.groups > 1:
            hidden = self.conv_merge(hidden)

        return self.skip_add.add(hidden, x)
class DPT(BaseModel):
    """Transformer-backbone depth network: hooked encoder, fusion decoder and output head."""

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
        **kwargs
    ):
        """Build the encoder for ``backbone`` and the refinement blocks.

        Args:
            head: module applied last; stored as ``scratch.output_conv``.
            features (int): channel width of the fusion blocks.
            backbone (str): key selecting the backbone and its hook indices;
                an unknown key raises ``KeyError``.
            readout (str): readout mode forwarded to the encoder factory.
            channels_last (bool): convert the input to channels-last memory
                format in ``forward``.
            use_bn (bool): use batch norm inside the fusion blocks.
            **kwargs: accepted for compatibility; not used here.
        """
        super(DPT, self).__init__()

        self.channels_last = channels_last

        # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
        # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
        hooks = {
            "beitl16_512": [5, 11, 17, 23],
            "beitl16_384": [5, 11, 17, 23],
            "beitb16_384": [2, 5, 8, 11],
            "swin2l24_384": [1, 1, 17, 1],  # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "swin2b24_384": [1, 1, 17, 1],  # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "swin2t16_256": [1, 1, 5, 1],  # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
            "swinl12_384": [1, 1, 17, 1],  # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "next_vit_large_6m": [2, 6, 36, 39],  # [0, 2], [3, 6], [ 7, 36], [37, 39]
            "levit_384": [3, 11, 21],  # [0, 3], [6, 11], [14, 21]
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }[backbone]

        if "next_vit" in backbone:
            in_features = {
                "next_vit_large_6m": [96, 256, 512, 1024],
            }[backbone]
        else:
            in_features = None

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks,
            use_readout=readout,
            in_features=in_features,
        )

        # LeViT exposes three hooked layers, every other backbone four.
        self.number_layers = len(hooks) if hooks is not None else 4
        size_refinenet3 = None
        self.scratch.stem_transpose = None

        if "beit" in backbone:
            self.forward_transformer = forward_beit
        elif "swin" in backbone:
            self.forward_transformer = forward_swin
        elif "next_vit" in backbone:
            from .backbones.next_vit import forward_next_vit
            self.forward_transformer = forward_next_vit
        elif "levit" in backbone:
            self.forward_transformer = forward_levit
            size_refinenet3 = 7
            self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
        else:
            self.forward_transformer = forward_vit

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
        if self.number_layers >= 4:
            self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head

    def forward(self, x):
        """Run encoder, fusion decoder and head on ``x`` and return the head output."""
        if self.channels_last is True:
            # Tensor.contiguous returns a new tensor; the result has to be
            # re-bound, otherwise the requested memory format is never applied.
            x = x.contiguous(memory_format=torch.channels_last)

        layers = self.forward_transformer(self.pretrained, x)
        if self.number_layers == 3:
            layer_1, layer_2, layer_3 = layers
        else:
            layer_1, layer_2, layer_3, layer_4 = layers

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        if self.number_layers >= 4:
            layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the deepest layer upwards; each block resizes its output
        # to the spatial size of the next shallower layer.
        if self.number_layers == 3:
            path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
        else:
            path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
            path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        if self.scratch.stem_transpose is not None:
            path_1 = self.scratch.stem_transpose(path_1)

        out = self.scratch.output_conv(path_1)

        return out
class MidasNet(BaseModel):
    """Monocular depth estimation network with a four-stage fusion decoder."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Build the encoder, the fusion decoder and the output head.

        Args:
            path (str, optional): Path to saved model. When truthy the weights
                are loaded through ``self.load``. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): Append a final ReLU to the head
                instead of an identity. Defaults to True.
        """
        print("Loading weights: ", path)

        super().__init__()

        self.pretrained, self.scratch = _make_encoder(
            backbone="resnext101_wsl",
            features=features,
            use_pretrained=path is not None,
        )

        # Registered coarsest-first, so the module order stays 4, 3, 2, 1.
        for level in (4, 3, 2, 1):
            setattr(self.scratch, f"refinenet{level}", FeatureFusionBlock(features))

        final_activation = nn.ReLU(True) if non_negative else nn.Identity()
        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            final_activation,
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth, i.e. the head output with dimension 1 squeezed
        """
        encoder, decoder = self.pretrained, self.scratch
        levels = (1, 2, 3, 4)

        # Encoder stages are chained: each consumes the previous stage's output.
        stage_outputs = []
        hidden = x
        for level in levels:
            hidden = getattr(encoder, f"layer{level}")(hidden)
            stage_outputs.append(hidden)

        skips = [
            getattr(decoder, f"layer{level}_rn")(feature)
            for level, feature in zip(levels, stage_outputs)
        ]

        # Fuse from the coarsest level down to the finest one.
        path = decoder.refinenet4(skips[3])
        for level in (3, 2, 1):
            path = getattr(decoder, f"refinenet{level}")(path, skips[level - 1])

        return torch.squeeze(decoder.output_conv(path), dim=1)
def forward(self, x):
    """Forward pass.

    Args:
        x (tensor): input data (image)

    Returns:
        tensor: depth, i.e. the head output with dimension 1 squeezed
    """
    if self.channels_last is True:
        print("self.channels_last = ", self.channels_last)
        # Tensor.contiguous() returns a new tensor and does not modify ``x``;
        # the result has to be kept or the conversion is silently dropped.
        x = x.contiguous(memory_format=torch.channels_last)

    # Encoder stages are chained: each consumes the previous stage's output.
    layer_1 = self.pretrained.layer1(x)
    layer_2 = self.pretrained.layer2(layer_1)
    layer_3 = self.pretrained.layer3(layer_2)
    layer_4 = self.pretrained.layer4(layer_3)

    layer_1_rn = self.scratch.layer1_rn(layer_1)
    layer_2_rn = self.scratch.layer2_rn(layer_2)
    layer_3_rn = self.scratch.layer3_rn(layer_3)
    layer_4_rn = self.scratch.layer4_rn(layer_4)

    # Fuse from the coarsest level down to the finest one.
    path_4 = self.scratch.refinenet4(layer_4_rn)
    path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
    path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
    path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

    out = self.scratch.output_conv(path_1)

    return torch.squeeze(out, dim=1)
def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
    """Load the specified network.

    Args:
        device (device): the torch device used
        model_path (str): path to saved model
        model_type (str): the type of the model to be loaded
        optimize (bool): optimize the model to half-floats on CUDA?
        height (int): inference encoder image height; when given it replaces
            both the network width and height
        square (bool): resize to a square resolution?

    Returns:
        The loaded network, the transform which prepares images as input to the network and the dimensions of the
        network input

    Raises:
        AssertionError: if ``model_type`` is not one of the supported names.

    Note:
        Requesting ``optimize`` on the CUDA device for an OpenVINO model prints
        an error and terminates the process through ``exit()``.
    """
    uses_openvino = "openvino" in model_type
    if uses_openvino:
        from openvino import Core

    # (model_type, backbone, square network side, force stretching, extra head kwargs)
    # dpt_levit_224 (MiDaS notation) maps to levit_384 (timm notation): the 224 refers to the resolution
    # 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
    # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
    # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
    dpt_variants = (
        ("dpt_beit_large_512", "beitl16_512", 512, False, {}),
        ("dpt_beit_large_384", "beitl16_384", 384, False, {}),
        ("dpt_beit_base_384", "beitb16_384", 384, False, {}),
        ("dpt_swin2_large_384", "swin2l24_384", 384, True, {}),
        ("dpt_swin2_base_384", "swin2b24_384", 384, True, {}),
        ("dpt_swin2_tiny_256", "swin2t16_256", 256, True, {}),
        ("dpt_swin_large_384", "swinl12_384", 384, True, {}),
        ("dpt_next_vit_large_384", "next_vit_large_6m", 384, False, {}),
        ("dpt_levit_224", "levit_384", 224, True, {"head_features_1": 64, "head_features_2": 8}),
        ("dpt_large_384", "vitl16_384", 384, False, {}),
        ("dpt_hybrid_384", "vitb_rn50_384", 384, False, {}),
    )

    keep_aspect_ratio = not square
    dpt_spec = next((spec for spec in dpt_variants if model_type == spec[0]), None)

    if dpt_spec is not None:
        _, backbone, side, force_stretch, head_kwargs = dpt_spec
        model = DPTDepthModel(
            path=model_path,
            backbone=backbone,
            non_negative=True,
            **head_kwargs,
        )
        net_w, net_h = side, side
        if force_stretch:
            keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    else:
        if model_type == "midas_v21_384":
            model = MidasNet(model_path, non_negative=True)
            net_w, net_h = 384, 384
        elif model_type == "midas_v21_small_256":
            model = MidasNet_small(
                model_path,
                features=64,
                backbone="efficientnet_lite3",
                exportable=True,
                non_negative=True,
                blocks={'expand': True},
            )
            net_w, net_h = 256, 256
        elif model_type == "openvino_midas_v21_small_256":
            ie = Core()
            uncompiled_model = ie.read_model(model=model_path)
            model = ie.compile_model(uncompiled_model, "CPU")
            net_w, net_h = 256, 256
        else:
            print(f"model_type '{model_type}' not implemented, use: --model_type large")
            raise AssertionError
        # All non-DPT variants share the same resize policy and normalization.
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    if uses_openvino:
        print("Model loaded, optimized with OpenVINO")
        keep_aspect_ratio = False
    else:
        million_params = sum(p.numel() for p in model.parameters()) / 1e6
        print("Model loaded, number of parameters = {:.0f}M".format(million_params))

    if height is not None:
        net_w, net_h = height, height

    resize = Resize(
        net_w,
        net_h,
        resize_target=None,
        keep_aspect_ratio=keep_aspect_ratio,
        ensure_multiple_of=32,
        resize_method=resize_mode,
        image_interpolation_method=cv2.INTER_CUBIC,
    )
    transform = Compose([resize, normalization, PrepareForNet()])

    if not uses_openvino:
        model.eval()

    if optimize and (device == torch.device("cuda")):
        if uses_openvino:
            print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
            exit()
        model = model.to(memory_format=torch.channels_last)
        model = model.half()

    if not uses_openvino:
        model.to(device)

    return model, transform, net_w, net_h
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller
                than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method (int, optional): OpenCV interpolation
                flag used for the image itself. Defaults to ``cv2.INTER_AREA``.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap ``x`` to a multiple of ``ensure_multiple_of``.

        ``x`` is rounded to the nearest multiple; if that exceeds ``max_val``
        it is rounded down instead, and if the result is below ``min_val`` it
        is rounded up.
        """
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output ``(width, height)`` for an input of the given size.

        Raises:
            ValueError: if the configured ``resize_method`` is unknown.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize ``sample["image"]`` and, if enabled, its targets in place.

        Args:
            sample (dict): must contain "image"; "disparity", "depth" and
                "mask" are resized only when present and ``resize_target``
                is truthy.

        Returns:
            dict: the same ``sample`` object.
        """
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            # Targets use nearest-neighbour interpolation.
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            # The mask is optional like the other targets; indexing it
            # unconditionally raised KeyError for samples without one.
            if "mask" in sample:
                resized_mask = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                sample["mask"] = resized_mask.astype(bool)

        return sample
""" def __init__(self): pass def __call__(self, sample): image = np.transpose(sample["image"], (2, 0, 1)) sample["image"] = np.ascontiguousarray(image).astype(np.float32) if "mask" in sample: sample["mask"] = sample["mask"].astype(np.float32) sample["mask"] = np.ascontiguousarray(sample["mask"]) if "disparity" in sample: disparity = sample["disparity"].astype(np.float32) sample["disparity"] = np.ascontiguousarray(disparity) if "depth" in sample: depth = sample["depth"].astype(np.float32) sample["depth"] = np.ascontiguousarray(depth) return sample ================================================ FILE: modules/control/proc/zoe/zoedepth/models/builder.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
def build_model(config) -> DepthModel:
    """Construct the model described by ``config``.

    The module ``zoedepth.models.<config.model>`` is imported, its
    ``get_version`` function is called with ``config.version_name`` and the
    returned object's ``build_from_config(config)`` produces the model.

    Args:
        config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s)
            saved in its root model folder.

    Returns:
        torch.nn.Module: Model corresponding to name and version as specified in config

    Raises:
        ValueError: if the model module cannot be found or it does not expose
            ``get_version``. The original exception is chained as the cause.
    """
    qualified_name = f"zoedepth.models.{config.model}"

    try:
        model_module = import_module(qualified_name)
    except ModuleNotFoundError as import_error:
        # print the original error message
        print(import_error)
        message = f"Model {config.model} not found. Refer above error for details."
        raise ValueError(message) from import_error

    try:
        version_factory = model_module.get_version
    except AttributeError as attr_error:
        message = f"Model {config.model} has no get_version function."
        raise ValueError(message) from attr_error

    model_interface = version_factory(config.version_name)
    return model_interface.build_from_config(config)
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # File author: Shariq Farooq Bhat import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms import PIL.Image from PIL import Image from typing import Union class DepthModel(nn.Module): def __init__(self, device='cpu'): super().__init__() self.device = device def to(self, device) -> nn.Module: self.device = device return super().to(device) def forward(self, x, *args, **kwargs): raise NotImplementedError def _infer(self, x: torch.Tensor): """ Inference interface for the model Args: x (torch.Tensor): input tensor of shape (b, c, h, w) Returns: torch.Tensor: output tensor of shape (b, 1, h, w) """ return self(x)['metric_depth'] def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: """ Inference interface for the model with padding augmentation Padding augmentation fixes the boundary artifacts in the output depth map. Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. This augmentation pads the input image and crops the prediction back to the original size / view. Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. Args: x (torch.Tensor): input tensor of shape (b, c, h, w) pad_input (bool, optional): whether to pad the input or not. Defaults to True. 
fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. padding_mode (str, optional): padding mode. Defaults to "reflect". Returns: torch.Tensor: output tensor of shape (b, 1, h, w) """ # assert x is nchw and c = 3 assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) if pad_input: assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" pad_h = int(np.sqrt(x.shape[2]/2) * fh) pad_w = int(np.sqrt(x.shape[3]/2) * fw) padding = [pad_w, pad_w] if pad_h > 0: padding += [pad_h, pad_h] x = F.pad(x, padding, mode=padding_mode, **kwargs) out = self._infer(x) if out.shape[-2:] != x.shape[-2:]: out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) if pad_input: # crop to the original size, handling the case where pad_h and pad_w is 0 if pad_h > 0: out = out[:, :, pad_h:-pad_h,:] if pad_w > 0: out = out[:, :, :, pad_w:-pad_w] return out def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: """ Inference interface for the model with horizontal flip augmentation Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. Args: x (torch.Tensor): input tensor of shape (b, c, h, w) pad_input (bool, optional): whether to use padding augmentation. Defaults to True. 
Returns: torch.Tensor: output tensor of shape (b, 1, h, w) """ # infer with horizontal flip and average out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) out = (out + torch.flip(out_flip, dims=[3])) / 2 return out def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: """ Inference interface for the model Args: x (torch.Tensor): input tensor of shape (b, c, h, w) pad_input (bool, optional): whether to use padding augmentation. Defaults to True. with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. Returns: torch.Tensor: output tensor of shape (b, 1, h, w) """ if with_flip_aug: return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) else: return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: """ Inference interface for the model for PIL image Args: pil_img (PIL.Image.Image): input PIL image pad_input (bool, optional): whether to use padding augmentation. Defaults to True. with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". 
""" x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) if output_type == "numpy": return out_tensor.squeeze().cpu().numpy() elif output_type == "pil": # uint16 is required for depth pil image out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) return Image.fromarray(out_16bit_numpy) elif output_type == "tensor": return out_tensor.squeeze().cpu() else: raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") ================================================ FILE: modules/control/proc/zoe/zoedepth/models/layers/__init__.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx, where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    distance = torch.abs(dx)
    # The weight decays with the distance raised to ``gamma``.
    weight = torch.exp(-alpha * torch.pow(distance, gamma))
    return weight * dx
Bin centers are bounded on the interval (min_depth, max_depth) """ super().__init__() self.n_attractors = n_attractors self.n_bins = n_bins self.min_depth = min_depth self.max_depth = max_depth self.alpha = alpha self.gamma = gamma self.kind = kind self.attractor_type = attractor_type self.memory_efficient = memory_efficient self._net = nn.Sequential( nn.Conv2d(in_features, mlp_dim, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm nn.ReLU(inplace=True) ) def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): """ Args: x (torch.Tensor) : feature block; shape - n, c, h, w b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w Returns: tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w """ if prev_b_embedding is not None: if interpolate: prev_b_embedding = nn.functional.interpolate( prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) x = x + prev_b_embedding A = self._net(x) eps = 1e-3 A = A + eps n, c, h, w = A.shape A = A.view(n, self.n_attractors, 2, h, w) A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w A_normed = A[:, :, 0, ...] 
# n, na, h, w b_prev = nn.functional.interpolate( b_prev, (h, w), mode='bilinear', align_corners=True) b_centers = b_prev if self.attractor_type == 'exp': dist = exp_attractor else: dist = inv_attractor if not self.memory_efficient: func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] # .shape N, nbins, h, w delta_c = func(dist(A_normed.unsqueeze( 2) - b_centers.unsqueeze(1)), dim=1) else: delta_c = torch.zeros_like(b_centers, device=b_centers.device) for i in range(self.n_attractors): # .shape N, nbins, h, w delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) if self.kind == 'mean': delta_c = delta_c / self.n_attractors b_new_centers = b_centers + delta_c B_centers = (self.max_depth - self.min_depth) * \ b_new_centers + self.min_depth B_centers, _ = torch.sort(B_centers, dim=1) B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) return b_new_centers, B_centers class AttractorLayerUnnormed(nn.Module): def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): """ Attractor layer for bin centers. Bin centers are unbounded """ super().__init__() self.n_attractors = n_attractors self.n_bins = n_bins self.min_depth = min_depth self.max_depth = max_depth self.alpha = alpha self.gamma = gamma self.kind = kind self.attractor_type = attractor_type self.memory_efficient = memory_efficient self._net = nn.Sequential( nn.Conv2d(in_features, mlp_dim, 1, 1, 0), nn.ReLU(inplace=True), nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), nn.Softplus() ) def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): """ Args: x (torch.Tensor) : feature block; shape - n, c, h, w b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w Returns: tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. 
Two outputs just to keep the API consistent with the normed version """ if prev_b_embedding is not None: if interpolate: prev_b_embedding = nn.functional.interpolate( prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) x = x + prev_b_embedding A = self._net(x) n, c, h, w = A.shape b_prev = nn.functional.interpolate( b_prev, (h, w), mode='bilinear', align_corners=True) b_centers = b_prev if self.attractor_type == 'exp': dist = exp_attractor else: dist = inv_attractor if not self.memory_efficient: func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] # .shape N, nbins, h, w delta_c = func( dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) else: delta_c = torch.zeros_like(b_centers, device=b_centers.device) for i in range(self.n_attractors): delta_c += dist(A[:, i, ...].unsqueeze(1) - b_centers) # .shape N, nbins, h, w if self.kind == 'mean': delta_c = delta_c / self.n_attractors b_new_centers = b_centers + delta_c B_centers = b_new_centers return b_new_centers, B_centers ================================================ FILE: modules/control/proc/zoe/zoedepth/models/layers/dist_layers.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
def log_binom(n, k, eps=1e-7):
    """Approximate log(nCk) with Stirling's formula.

    Both n and k are shifted by eps before use, and eps is added once more
    inside the last logarithm, so that no log(0) is evaluated.
    """
    n_shifted = n + eps
    k_shifted = k + eps
    gap = n_shifted - k_shifted
    return n_shifted * torch.log(n_shifted) - k_shifted * torch.log(k_shifted) - gap * torch.log(gap + eps)


class LogBinomial(nn.Module):
    def __init__(self, n_classes=256, act=torch.softmax):
        """Log binomial distribution over n_classes classes.

        Args:
            n_classes (int, optional): number of output classes. Defaults to 256.
            act (callable, optional): applied as act(logits / t, dim=1) in forward(). Defaults to torch.softmax.
        """
        super().__init__()
        self.K = n_classes
        self.act = act
        # class indices 0..K-1 and the constant K-1, both shaped to broadcast
        # along the channel axis of an NCHW tensor
        class_index = torch.arange(0, n_classes).view(1, -1, 1, 1)
        last_index = torch.Tensor([self.K-1]).view(1, -1, 1, 1)
        self.register_buffer('k_idx', class_index)
        self.register_buffer('K_minus_1', last_index)

    def forward(self, x, t=1., eps=1e-4):
        """Evaluate the log binomial distribution for probabilities x.

        Args:
            x (torch.Tensor - NCHW): probabilities; a 3-D input gets a channel axis inserted at dim 1
            t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
            eps (float, optional): Small number for numerical stability. Defaults to 1e-4.

        Returns:
            torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # make it nchw

        # clamp p and 1 - p into [eps, 1] before taking logs
        q = torch.clamp(1 - x, eps, 1)
        p = torch.clamp(x, eps, 1)

        successes = self.k_idx
        failures = self.K - 1 - self.k_idx
        logits = log_binom(self.K_minus_1, successes) + successes * torch.log(p)
        logits = logits + failures * torch.log(q)
        return self.act(logits/t, dim=1)
def _bin_centers_from_widths(widths, first_edge):
    """Convert bin widths (NCHW, bins on dim 1) into bin centers.

    One extra leading channel filled with ``first_edge`` is prepended, the
    cumulative sum along dim 1 yields the bin edges, and every center is the
    midpoint of two neighbouring edges.
    """
    # pad has the form (left, right, top, bottom, front, back)
    padded = nn.functional.pad(
        widths, (0, 0, 0, 0, 1, 0), mode='constant', value=first_edge)
    edges = torch.cumsum(padded, dim=1)  # .shape NCHW
    return 0.5 * (edges[:, :-1, ...] + edges[:, 1:, ...])


class SeedBinRegressor(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()
        self.version = "1_1"
        self.min_depth = min_depth
        self.max_depth = max_depth

        layers = [
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.ReLU(inplace=True),
        ]
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Returns the normalised bin widths and the bin centers derived from them, one vector per pixel
        """
        # the small offset keeps every width strictly positive after the ReLU
        raw = self._net(x) + 1e-3
        widths_normed = raw / raw.sum(dim=1, keepdim=True)
        depth_range = self.max_depth - self.min_depth
        widths = depth_range * widths_normed  # .shape NCHW
        return widths_normed, _bin_centers_from_widths(widths, self.min_depth)


class SeedBinRegressorUnnormed(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are unbounded

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
            max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
        """
        super().__init__()
        self.version = "1_1"

        layers = [
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.Softplus(),
        ]
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Returns the network output twice (same tensor object), mirroring the two-value return of SeedBinRegressor
        """
        centers = self._net(x)
        return centers, centers


class Projector(nn.Module):
    def __init__(self, in_features, out_features, mlp_dim=128):
        """Projector MLP built from two 1x1 convolutions

        Args:
            in_features (int): input channels
            out_features (int): output channels
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
        """
        super().__init__()

        layers = [
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
        ]
        self._net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self._net(x)
        return projected


class LinearSplitter(nn.Module):
    def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
        """Splits each of prev_nbins bins into split_factor child bins.

        Args:
            in_features (int): input channels
            prev_nbins (int): number of bins in the previous stage
            split_factor (int, optional): children per previous bin. Defaults to 2.
            mlp_dim (int, optional): Hidden dimension. Defaults to 128.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()

        self.prev_nbins = prev_nbins
        self.split_factor = split_factor
        self.min_depth = min_depth
        self.max_depth = max_depth

        layers = [
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
            nn.ReLU(),
        ]
        self._net = nn.Sequential(*layers)

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        x : feature block; shape - n, c, h, w
        b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
        Returns the child bin widths (normed) and the bin centers computed from them
        """
        resize = nn.functional.interpolate
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = resize(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        splits = self._net(x) + 1e-3
        n, _, h, w = splits.shape
        splits = splits.view(n, self.prev_nbins, self.split_factor, h, w)
        fractions = splits / splits.sum(dim=2, keepdim=True)  # fractional splits

        b_prev = resize(b_prev, (h, w), mode='bilinear', align_corners=True)
        b_prev = b_prev / b_prev.sum(dim=1, keepdim=True)  # renormalize for guarantees

        # each parent width is shared out among its children
        b = (b_prev.unsqueeze(2) * fractions).flatten(1, 2)  # .shape n, prev_nbins * split_factor, h, w

        # calculate bin centers for loss calculation
        widths = (self.max_depth - self.min_depth) * b  # .shape N, nprev * splitfactor, H, W
        return b, _bin_centers_from_widths(widths, self.min_depth)
""" super(PatchTransformerEncoder, self).__init__() self.use_class_token = use_class_token encoder_layers = nn.TransformerEncoderLayer( embedding_dim, num_heads, dim_feedforward=1024) self.transformer_encoder = nn.TransformerEncoder( encoder_layers, num_layers=4) # takes shape S,N,E self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size, padding=0) def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): """Generate positional encodings Args: sequence_length (int): Sequence length embedding_dim (int): Embedding dimension Returns: torch.Tensor SBE: Positional encodings """ position = torch.arange( 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) index = torch.arange( 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) pos_encoding = position * div_term pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) return pos_encoding def forward(self, x): """Forward pass Args: x (torch.Tensor - NCHW): Input feature tensor Returns: torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim """ embeddings = self.embedding_convPxP(x).flatten( 2) # .shape = n,c,s = n, embedding_dim, s if self.use_class_token: # extra special token at start ? 
def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict.

    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the model is not a DataParallel model but the state_dict is, then prefixes are removed.
    If the model is a DataParallel model but the state_dict is not, then prefixes are added.

    Args:
        model (torch.nn.Module): model that receives the weights
        state_dict (dict): the weights, either directly or nested under a 'model' key

    Returns:
        torch.nn.Module: the same model object, after loading
    """
    state_dict = state_dict.get('model', state_dict)
    prefix = 'module.'
    # if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
    do_prefix = isinstance(
        model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    state = {}
    for k, v in state_dict.items():
        if k.startswith(prefix) and not do_prefix:
            k = k[len(prefix):]

        if not k.startswith(prefix) and do_prefix:
            k = prefix + k

        state[k] = v

    model.load_state_dict(state)
    print("Loaded successfully")
    return model


def load_wts(model, checkpoint_path):
    """Read a checkpoint file onto the CPU and load it into model via load_state_dict."""
    # SECURITY NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    return load_state_dict(model, ckpt)


def load_state_dict_from_url(model, url, **kwargs):
    """Fetch a checkpoint with torch.hub (mapped to CPU) and load it into model.

    Extra keyword arguments are forwarded to torch.hub.load_state_dict_from_url.
    """
    # SECURITY NOTE(review): the downloaded file is unpickled; use trusted URLs only.
    state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
    return load_state_dict(model, state_dict)


def load_state_from_resource(model, resource: str):
    """Loads weights to the model from a given resource. A resource can be of following types:
        1. URL. Prefixed with "url::"
                e.g. url::http(s)://url.resource.com/ckpt.pt

        2. Local path. Prefixed with "local::"
                e.g. local::/path/to/ckpt.pt


    Args:
        model (torch.nn.Module): Model
        resource (str): resource string

    Returns:
        torch.nn.Module: Model with loaded weights

    Raises:
        ValueError: if resource starts with neither "url::" nor "local::"
    """
    print(f"Using pretrained resource {resource}")

    url_prefix = 'url::'
    local_prefix = 'local::'

    # Strip only the leading marker. Splitting on the marker would cut the
    # URL/path short whenever the same text occurs again later in the string.
    if resource.startswith(url_prefix):
        url = resource[len(url_prefix):]
        return load_state_dict_from_url(model, url, progress=True)

    if resource.startswith(local_prefix):
        path = resource[len(local_prefix):]
        return load_wts(model, path)

    raise ValueError("Invalid resource type, only url:: and local:: are supported")
# Registry mapping a version key to the model class implementing it.
all_versions = {
    "v1": ZoeDepth,
}


def get_version(v):
    """Return the model class registered in ``all_versions`` under key ``v``.

    Raises:
        KeyError: if ``v`` is not a registered version key.
    """
    return all_versions[v]
false, "use_pretrained_midas": false, "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" } } ================================================ FILE: modules/control/proc/zoe/zoedepth/models/zoedepth/zoedepth_v1.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
class ZoeDepth(DepthModel):
    def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10,
                 n_attractors=None, attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True,
                 midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
        """ZoeDepth model. This is the version of ZoeDepth that has a single metric head

        Args:
            core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features.
                This constructor reads ``core.output_channels`` and may call ``core.freeze_encoder``.
            n_bins (int, optional): Number of bin centers. Defaults to 64.
            bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers.
                                               For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus".
                                               "hybrid1" (normed seed regressor, unnormed attractors) and "hybrid2" (the reverse) are also accepted.
            bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.
            min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3.
            max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10.
            n_attractors (List[int], optional): Number of bin attractors at decoder layers. None is replaced by [16, 8, 4, 1].
            attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
            attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
            attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
            attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.
            min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
            max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.
            train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
            midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10.
            encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
                A value <= 0 freezes the encoder.
            pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
                A value <= 0 is passed on as ``freeze_rel_pos=True`` when the encoder is frozen.
            inverse_midas (bool, optional): If True, forward() inverts and min-max normalizes the relative depth before using it as a condition. Defaults to False.
            **kwargs: accepted but not read anywhere in this constructor.

        Raises:
            ValueError: if bin_centers_type is not one of 'normed', 'softplus', 'hybrid1', 'hybrid2'.
        """
        if n_attractors is None:
            n_attractors = [16, 8, 4, 1]
        super().__init__()

        self.core = core
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.min_temp = min_temp
        self.bin_centers_type = bin_centers_type

        self.midas_lr_factor = midas_lr_factor
        self.encoder_lr_factor = encoder_lr_factor
        self.pos_enc_lr_factor = pos_enc_lr_factor
        self.train_midas = train_midas
        self.inverse_midas = inverse_midas

        # a non-positive encoder factor means the encoder is not trained at all
        if self.encoder_lr_factor <= 0:
            self.core.freeze_encoder(
                freeze_rel_pos=self.pos_enc_lr_factor <= 0)

        # NOTE(review): assumed channel count of the first core output, which
        # forward() concatenates with the 1-channel relative depth - confirm
        # against MidasCore.
        N_MIDAS_OUT = 32
        # first entry describes the bottleneck, the remaining ones the decoder blocks
        btlnck_features = self.core.output_channels[0]
        num_out_features = self.core.output_channels[1:]

        self.conv2 = nn.Conv2d(btlnck_features, btlnck_features,
                               kernel_size=1, stride=1, padding=0)  # btlnck conv

        # pick the seed regressor / attractor implementations for the requested bin type
        if bin_centers_type == "normed":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayer
        elif bin_centers_type == "softplus":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
        elif bin_centers_type == "hybrid1":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayerUnnormed
        elif bin_centers_type == "hybrid2":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayer
        else:
            raise ValueError(
                "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")

        self.seed_bin_regressor = SeedBinRegressorLayer(
            btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth)
        self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
        # one projector and one attractor per decoder block
        self.projectors = nn.ModuleList([
            Projector(num_out, bin_embedding_dim)
            for num_out in num_out_features
        ])
        # NOTE(review): memory_efficient is not passed here, so each attractor
        # layer uses its own default for it.
        self.attractors = nn.ModuleList([
            Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth,
                      alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
            for i in range(len(num_out_features))
        ])

        last_in = N_MIDAS_OUT + 1  # +1 for relative depth

        # use log binomial instead of softmax
        self.conditional_log_binomial = ConditionalLogBinomial(
            last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)

    def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
        """
        Args:
            x (torch.Tensor): Input image tensor of shape (B, C, H, W)
            return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False.
            denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False.
            return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False.
            **kwargs: accepted but not read in this method.

        Returns:
            dict: Dictionary containing the following keys:
                - metric_depth (torch.Tensor): Metric depth map with a single channel (sum over bins, keepdim=True)
                - bin_centers (torch.Tensor): Final bin centers, resized to the spatial size of the probabilities. Present if return_final_centers or return_probs is True
                - probs (torch.Tensor): Output probability distribution over the bins. Present only if return_probs is True

            Note that no 'rel_depth' key is added to the returned dict.
        """
        b, c, h, w = x.shape
        # print("input shape ", x.shape)
        # remember the input resolution on the module
        self.orig_input_width = w
        self.orig_input_height = h
        rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)
        # print("output shapes", rel_depth.shape, out.shape)

        # out[0] is used as the final conv activation, out[1] as the bottleneck,
        # and the remaining entries as the decoder blocks
        outconv_activation = out[0]
        btlnck = out[1]
        x_blocks = out[2:]

        x_d0 = self.conv2(btlnck)
        x = x_d0
        _, seed_b_centers = self.seed_bin_regressor(x)

        # attractor layers of the normed kind expect centers rescaled by the depth range
        if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
            b_prev = (seed_b_centers - self.min_depth) / \
                (self.max_depth - self.min_depth)
        else:
            b_prev = seed_b_centers

        prev_b_embedding = self.seed_projector(x)

        # unroll this loop for better performance
        # NOTE: `x` and `b` are rebound inside the loop; b_embedding and
        # b_centers from the last iteration are used after it, so at least one
        # decoder block is required.
        for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
            b_embedding = projector(x)
            b, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            b_prev = b.clone()
            prev_b_embedding = b_embedding.clone()

        last = outconv_activation

        if self.inverse_midas:
            # invert depth followed by normalization
            rel_depth = 1.0 / (rel_depth + 1e-6)
            rel_depth = (rel_depth - rel_depth.min()) / \
                (rel_depth.max() - rel_depth.min())
        # concat rel depth with last. First interpolate rel depth to last size
        rel_cond = rel_depth.unsqueeze(1)
        rel_cond = nn.functional.interpolate(
            rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
        last = torch.cat([last, rel_cond], dim=1)

        # bring the bin embedding to the same spatial size as `last`
        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
        x = self.conditional_log_binomial(last, b_embedding)

        # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
        # print(x.shape, b_centers.shape)
        b_centers = nn.functional.interpolate(
            b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
        out = torch.sum(x * b_centers, dim=1, keepdim=True)

        # Structure output dict
        output = dict(metric_depth=out)
        if return_final_centers or return_probs:
            output['bin_centers'] = b_centers

        if return_probs:
            output['probs'] = x

        return output
""" param_conf = [] if self.train_midas: if self.encoder_lr_factor > 0: param_conf.append({'params': self.core.get_enc_params_except_rel_pos( ), 'lr': lr / self.encoder_lr_factor}) if self.pos_enc_lr_factor > 0: param_conf.append( {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) midas_params = self.core.core.scratch.parameters() midas_lr_factor = self.midas_lr_factor param_conf.append( {'params': midas_params, 'lr': lr / midas_lr_factor}) remaining_modules = [] for name, child in self.named_children(): if name != 'core': remaining_modules.append(child) remaining_params = itertools.chain( *[child.parameters() for child in remaining_modules]) param_conf.append({'params': remaining_params, 'lr': lr}) return param_conf @staticmethod def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) model = ZoeDepth(core, **kwargs) if pretrained_resource: assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" model = load_state_from_resource(model, pretrained_resource) return model @staticmethod def build_from_config(config): return ZoeDepth.build(**config) ================================================ FILE: modules/control/proc/zoe/zoedepth/models/zoedepth_nk/__init__.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom 
# Registry of available ZoeDepthNK implementations, keyed by version name.
all_versions = {
    "v1": ZoeDepthNK,
}


def get_version(v):
    """Return the model class registered under version key ``v`` (KeyError if unknown)."""
    return all_versions[v]
"midas_lr_factor": 10, "encoder_lr_factor":10, "pos_enc_lr_factor":10 }, "infer": { "train_midas": false, "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", "use_pretrained_midas": false, "force_keep_ar": true }, "eval": { "train_midas": false, "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", "use_pretrained_midas": false } } ================================================ FILE: modules/control/proc/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py ================================================ # MIT License # Copyright (c) 2022 Intelligent Systems Lab Org # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
class ZoeDepthNK(DepthModel):
    def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128,
                 n_attractors=None, attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp',
                 min_temp=5, max_temp=50,
                 memory_efficient=False, train_midas=True,
                 is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs):
        """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts.

        Args:
            core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features

            bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys:
                                    "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float)

                                   The length of this list determines the number of metric heads.
            bin_centers_type (str, optional): One of "normed", "softplus", "hybrid1", "hybrid2". Selects the seed bin regressor / attractor layer pair. Defaults to "softplus".
            bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128.

            n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1].
            attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300.
            attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2.
            attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'.
            attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'.

            min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5.
            max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50.

            memory_efficient (bool, optional): Forwarded to the attractor layers as ``memory_efficient``. Defaults to False.

            train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True.
            is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True.
            midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 1.
            encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10.
            pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10.
            inverse_midas (bool, optional): Stored on the instance; not read by this class's forward pass. Defaults to False.

        Raises:
            ValueError: If ``bin_centers_type`` is not one of the supported values.
        """
        if n_attractors is None:
            n_attractors = [16, 8, 4, 1]
        super().__init__()

        self.core = core
        self.bin_conf = bin_conf
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.memory_efficient = memory_efficient
        self.train_midas = train_midas
        self.is_midas_pretrained = is_midas_pretrained
        self.midas_lr_factor = midas_lr_factor
        self.encoder_lr_factor = encoder_lr_factor
        self.pos_enc_lr_factor = pos_enc_lr_factor
        self.inverse_midas = inverse_midas

        N_MIDAS_OUT = 32
        btlnck_features = self.core.output_channels[0]
        num_out_features = self.core.output_channels[1:]

        self.conv2 = nn.Conv2d(
            btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0)

        # Transformer classifier on the bottleneck
        self.patch_transformer = PatchTransformerEncoder(
            btlnck_features, 1, 128, use_class_token=True)
        self.mlp_classifier = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

        if bin_centers_type == "normed":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayer
        elif bin_centers_type == "softplus":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
        elif bin_centers_type == "hybrid1":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayerUnnormed
        elif bin_centers_type == "hybrid2":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayer
        else:
            raise ValueError(
                "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")
        self.bin_centers_type = bin_centers_type

        # We have bins for each bin conf.
        # Create a map (ModuleDict) of 'name' -> seed_bin_regressor
        self.seed_bin_regressors = nn.ModuleDict(
            {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"])
             for conf in bin_conf}
        )

        self.seed_projector = Projector(
            btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
        self.projectors = nn.ModuleList([
            Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2)
            for num_out in num_out_features
        ])

        # Create a map (ModuleDict) of 'name' -> attractors (ModuleList)
        # NOTE(review): the second positional argument is n_attractors[i]; kept exactly
        # as-is because changing layer construction would change the parameter layout.
        self.attractors = nn.ModuleDict(
            {conf['name']: nn.ModuleList([
                Attractor(bin_embedding_dim, n_attractors[i],
                          mlp_dim=bin_embedding_dim, alpha=attractor_alpha,
                          gamma=attractor_gamma, kind=attractor_kind,
                          attractor_type=attractor_type, memory_efficient=memory_efficient,
                          min_depth=conf["min_depth"], max_depth=conf["max_depth"])
                for i in range(len(n_attractors))
            ])
                for conf in bin_conf}
        )

        last_in = N_MIDAS_OUT
        # conditional log binomial for each bin conf
        self.conditional_log_binomial = nn.ModuleDict(
            {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp)
             for conf in bin_conf}
        )

    def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs):
        """Route the batch to one metric head and predict metric depth.

        Args:
            x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain.
            return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False.
            denorm (bool, optional): Forwarded to the core model as ``denorm``. Defaults to False.
            return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False.

        Returns:
            dict: Dictionary of outputs with keys:
                - "domain_logits": Output of the domain classifier
                - "metric_depth": Sum over bins of probability * bin center (dim 1 kept)
                - "bin_centers": Present only if return_final_centers or return_probs is True
                - "probs": Present only if return_probs is True

        Raises:
            ValueError: If the routed head name has no entry in ``self.bin_conf``.
        """
        b, c, h, w = x.shape
        self.orig_input_width = w
        self.orig_input_height = h
        # The relative depth returned by the core is not used by this model.
        rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True)

        outconv_activation = out[0]
        btlnck = out[1]
        x_blocks = out[2:]

        x_d0 = self.conv2(btlnck)
        x = x_d0

        # Predict which path to take
        embedding = self.patch_transformer(x)[0]  # N, E
        domain_logits = self.mlp_classifier(embedding)  # N, 2
        # Logits are summed over the batch so the whole batch takes a single path.
        domain_vote = torch.softmax(domain_logits.sum(
            dim=0, keepdim=True), dim=-1)  # 1, 2

        # Get the path
        bin_conf_name = ["nyu", "kitti"][torch.argmax(
            domain_vote, dim=-1).squeeze().item()]

        # Key access works for plain dicts as well as attribute-style dicts, matching
        # how bin_conf entries are read everywhere else in this class.
        conf = next(
            (cfg for cfg in self.bin_conf if cfg['name'] == bin_conf_name), None)
        if conf is None:
            raise ValueError(f"bin_conf_name {bin_conf_name} not found in bin_confs")

        min_depth = conf['min_depth']
        max_depth = conf['max_depth']

        seed_bin_regressor = self.seed_bin_regressors[bin_conf_name]
        _, seed_b_centers = seed_bin_regressor(x)
        if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
            b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth)
        else:
            b_prev = seed_b_centers
        prev_b_embedding = self.seed_projector(x)

        attractors = self.attractors[bin_conf_name]
        for projector, attractor, x in zip(self.projectors, attractors, x_blocks):
            b_embedding = projector(x)
            b, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            b_prev = b
            prev_b_embedding = b_embedding

        last = outconv_activation

        # Bring centers and embedding to the spatial size of the final activation.
        b_centers = nn.functional.interpolate(
            b_centers, last.shape[-2:], mode='bilinear', align_corners=True)
        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

        clb = self.conditional_log_binomial[bin_conf_name]
        x = clb(last, b_embedding)

        # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor
        out = torch.sum(x * b_centers, dim=1, keepdim=True)

        output = dict(domain_logits=domain_logits, metric_depth=out)
        if return_final_centers or return_probs:
            output['bin_centers'] = b_centers

        if return_probs:
            output['probs'] = x
        return output

    def get_lr_params(self, lr):
        """
        Learning rate configuration for different layers of the model

        Args:
            lr (float) : Base learning rate
        Returns:
            list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
        """
        param_conf = []
        if self.train_midas:
            def get_rel_pos_params():
                for name, p in self.core.core.pretrained.named_parameters():
                    if "relative_position" in name:
                        yield p

            def get_enc_params_except_rel_pos():
                for name, p in self.core.core.pretrained.named_parameters():
                    if "relative_position" not in name:
                        yield p

            encoder_params = get_enc_params_except_rel_pos()
            rel_pos_params = get_rel_pos_params()
            midas_params = self.core.core.scratch.parameters()
            # The reduction factor only applies when the core starts from pretrained weights.
            midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0
            param_conf.extend([
                {'params': encoder_params, 'lr': lr / self.encoder_lr_factor},
                {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor},
                {'params': midas_params, 'lr': lr / midas_lr_factor}
            ])

        # Everything that is not part of the core trains at the base learning rate.
        remaining_modules = [
            child for name, child in self.named_children() if name != 'core']
        remaining_params = itertools.chain(
            *[child.parameters() for child in remaining_modules])
        param_conf.append({'params': remaining_params, 'lr': lr})
        return param_conf

    def get_conf_parameters(self, conf_name):
        """
        Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
        """
        params = []
        for _name, child in self.named_children():
            if isinstance(child, nn.ModuleDict):
                for bin_conf_name, module in child.items():
                    if bin_conf_name == conf_name:
                        params += list(module.parameters())
        return params

    def freeze_conf(self, conf_name):
        """
        Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
        """
        for p in self.get_conf_parameters(conf_name):
            p.requires_grad = False

    def unfreeze_conf(self, conf_name):
        """
        Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration
        """
        for p in self.get_conf_parameters(conf_name):
            p.requires_grad = True

    def freeze_all_confs(self):
        """
        Freezes all the parameters of all the ModuleDicts children
        """
        for _name, child in self.named_children():
            if isinstance(child, nn.ModuleDict):
                for _bin_conf_name, module in child.items():
                    for p in module.parameters():
                        p.requires_grad = False

    @staticmethod
    def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs):
        """Construct a ZoeDepthNK model on top of a freshly built MidasCore.

        ``kwargs`` are forwarded both to ``MidasCore.build`` and to the ``ZoeDepthNK``
        constructor. When ``pretrained_resource`` is truthy it must be a string and the
        model state is loaded from it via ``load_state_from_resource``.
        """
        core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas,
                               train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs)
        model = ZoeDepthNK(core, **kwargs)
        if pretrained_resource:
            assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
            model = load_state_from_resource(model, pretrained_resource)
        return model

    @staticmethod
    def build_from_config(config):
        """Build a ZoeDepthNK model, passing every entry of ``config`` as a keyword argument."""
        return ZoeDepthNK.build(**config)
def infer_type(x):
    """Best-effort conversion of a string argument to ``int`` or ``float``.

    Non-string inputs are returned unchanged. Strings are tried as ``int`` first,
    then as ``float``; if neither parse succeeds the original string is returned.
    """
    if not isinstance(x, str):
        return x

    for cast in (int, float):
        try:
            return cast(x)
        except ValueError:
            continue
    return x


def parse_unknown(unknown_args):
    """Turn a flat list of leftover CLI tokens into a ``{key: value}`` dict.

    Both ``--key value`` and ``--key=value`` forms are accepted. Every ``--`` is
    removed from keys and values are passed through :func:`infer_type`.

    Args:
        unknown_args (list[str]): Tokens such as ``["--lr=0.1", "--bs", "16"]``.

    Returns:
        dict: Parsed key/value pairs. A trailing key without a value is dropped,
        because keys and values are paired with ``zip``.
    """
    clean = []
    for arg in unknown_args:
        if "=" in arg:
            # Split on the first '=' only, so values may themselves contain '='
            # (URLs with query strings, nested key=value payloads, ...).
            key, value = arg.split("=", 1)
            clean.extend([key, value])
        else:
            clean.append(arg)

    keys = clean[::2]
    values = clean[1::2]
    return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)}
publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # File author: Shariq Farooq Bhat import json import os from .easydict import EasyDict as edict from .arg_utils import infer_type import pathlib import platform ROOT = pathlib.Path(__file__).parent.parent.resolve() HOME_DIR = os.path.expanduser("~") COMMON_CONFIG = { "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), "project": "ZoeDepth", "tags": '', "notes": "", "gpu": None, "root": ".", "uid": None, "print_losses": False } DATASETS_CONFIG = { "kitti": { "dataset": "kitti", "min_depth": 0.001, "max_depth": 80, "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", "input_height": 352, "input_width": 1216, # 704 "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", "min_depth_eval": 1e-3, "max_depth_eval": 80, "do_random_rotate": True, "degree": 1.0, "do_kb_crop": True, "garg_crop": True, "eigen_crop": False, "use_right": False }, "kitti_test": { 
"dataset": "kitti", "min_depth": 0.001, "max_depth": 80, "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", "input_height": 352, "input_width": 1216, "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", "min_depth_eval": 1e-3, "max_depth_eval": 80, "do_random_rotate": False, "degree": 1.0, "do_kb_crop": True, "garg_crop": True, "eigen_crop": False, "use_right": False }, "nyu": { "dataset": "nyu", "avoid_boundary": False, "min_depth": 1e-3, # originally 0.1 "max_depth": 10, "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", "input_height": 480, "input_width": 640, "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", "min_depth_eval": 1e-3, "max_depth_eval": 10, "min_depth_diff": -10, "max_depth_diff": 10, "do_random_rotate": True, "degree": 1.0, "do_kb_crop": False, "garg_crop": False, "eigen_crop": True }, "ibims": { "dataset": "ibims", "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), "eigen_crop": True, "garg_crop": False, "do_kb_crop": False, "min_depth_eval": 0, "max_depth_eval": 10, "min_depth": 1e-3, "max_depth": 10 }, "sunrgbd": { "dataset": "sunrgbd", "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), "eigen_crop": True, "garg_crop": False, "do_kb_crop": False, 
"min_depth_eval": 0, "max_depth_eval": 8, "min_depth": 1e-3, "max_depth": 10 }, "diml_indoor": { "dataset": "diml_indoor", "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), "eigen_crop": True, "garg_crop": False, "do_kb_crop": False, "min_depth_eval": 0, "max_depth_eval": 10, "min_depth": 1e-3, "max_depth": 10 }, "diml_outdoor": { "dataset": "diml_outdoor", "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), "eigen_crop": False, "garg_crop": True, "do_kb_crop": False, "min_depth_eval": 2, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 80 }, "diode_indoor": { "dataset": "diode_indoor", "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), "eigen_crop": True, "garg_crop": False, "do_kb_crop": False, "min_depth_eval": 1e-3, "max_depth_eval": 10, "min_depth": 1e-3, "max_depth": 10 }, "diode_outdoor": { "dataset": "diode_outdoor", "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), "eigen_crop": False, "garg_crop": True, "do_kb_crop": False, "min_depth_eval": 1e-3, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 80 }, "hypersim_test": { "dataset": "hypersim_test", "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), "eigen_crop": True, "garg_crop": False, "do_kb_crop": False, "min_depth_eval": 1e-3, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 10 }, "vkitti": { "dataset": "vkitti", "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), "eigen_crop": False, "garg_crop": True, "do_kb_crop": True, "min_depth_eval": 1e-3, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 80 }, "vkitti2": { "dataset": "vkitti2", "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), "eigen_crop": False, "garg_crop": True, "do_kb_crop": True, "min_depth_eval": 1e-3, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 80, }, "ddad": { "dataset": "ddad", "ddad_root": 
os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), "eigen_crop": False, "garg_crop": True, "do_kb_crop": True, "min_depth_eval": 1e-3, "max_depth_eval": 80, "min_depth": 1e-3, "max_depth": 80, }, } ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR COMMON_TRAINING_CONFIG = { "dataset": "nyu", "distributed": True, "workers": 16, "clip_grad": 0.1, "use_shared_dict": False, "shared_dict": None, "use_amp": False, "aug": True, "random_crop": False, "random_translate": False, "translate_prob": 0.2, "max_translation": 100, "validate_every": 0.25, "log_images_every": 0.1, "prefetch": False, } def flatten(config, except_keys=('bin_conf')): def recurse(inp): if isinstance(inp, dict): for key, value in inp.items(): if key in except_keys: yield (key, value) if isinstance(value, dict): yield from recurse(value) else: yield (key, value) return dict(list(recurse(config))) def split_combined_args(kwargs): """Splits the arguments that are combined with '__' into multiple arguments. Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001' Args: kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. Returns: dict: Parsed dict with the combined arguments split into individual key-value pairs. """ new_kwargs = dict(kwargs) for key, value in kwargs.items(): if key.startswith("__"): keys = key.split("__")[1:] values = value.split(";") assert len(keys) == len( values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. 
Given (keys,values) is ({keys}, {values})" for k, v in zip(keys, values): new_kwargs[k] = v return new_kwargs def parse_list(config, key, dtype=int): """Parse a list of values for the key if the value is a string. The values are separated by a comma. Modifies the config in place. """ if key in config: if isinstance(config[key], str): config[key] = list(map(dtype, config[key].split(','))) assert isinstance(config[key], list) and all(isinstance(e, dtype) for e in config[key] ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." def get_model_config(model_name, model_version=None): """Find and parse the .json config file for the model. Args: model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. Returns: easydict: the config dictionary for the model. 
""" config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" config_file = os.path.join(ROOT, "models", model_name, config_fname) if not os.path.exists(config_file): return None with open(config_file, "r") as f: config = edict(json.load(f)) # handle dictionary inheritance # only training config is supported for inheritance if "inherit" in config.train and config.train.inherit is not None: inherit_config = get_model_config(config.train["inherit"]).train for key, value in inherit_config.items(): if key not in config.train: config.train[key] = value return edict(config) def update_model_config(config, mode, model_name, model_version=None, strict=False): model_config = get_model_config(model_name, model_version) if model_config is not None: config = {**config, ** flatten({**model_config.model, **model_config[mode]})} elif strict: raise ValueError(f"Config file for model {model_name} not found.") return config def check_choices(name, value, choices): # return # No checks in dev branch if value not in choices: raise ValueError(f"{name} {value} not in supported choices {choices}") KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): """Main entry point to get the config for the model. Args: model_name (str): name of the desired model. mode (str, optional): "train" or "infer". Defaults to 'train'. dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. Keyword Args: key-value pairs of arguments to overwrite the default config. The order of precedence for overwriting the config is (Higher precedence first): # 1. overwrite_kwargs # 2. "config_version": Config file version if specified in overwrite_kwargs. 
The corresponding config loaded is config_{model_name}_{config_version}.json # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json # 4. common_config: Default config for all models specified in COMMON_CONFIG Returns: easydict: The config dictionary for the model. """ check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) check_choices("Mode", mode, ["train", "infer", "eval"]) if mode == "train": check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) config = update_model_config(config, mode, model_name) # update with model version specific config version_name = overwrite_kwargs.get("version_name", config["version_name"]) config = update_model_config(config, mode, model_name, version_name) # update with config version if specified config_version = overwrite_kwargs.get("config_version", None) if config_version is not None: print("Overwriting config with config_version", config_version) config = update_model_config(config, mode, model_name, config_version) # update with overwrite_kwargs # Combined args are useful for hyperparameter search overwrite_kwargs = split_combined_args(overwrite_kwargs) config = {**config, **overwrite_kwargs} # Casting to bool for key in KEYS_TYPE_BOOL: if key in config: config[key] = bool(config[key]) # Model specific post processing of config parse_list(config, "n_attractors") # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: bin_conf = config['bin_conf'] # list of dicts n_bins = overwrite_kwargs['n_bins'] new_bin_conf = [] for conf in bin_conf: conf['n_bins'] = n_bins new_bin_conf.append(conf) config['bin_conf'] = new_bin_conf if mode == "train": orig_dataset = dataset if dataset == "mix": dataset = 'nyu' # Use nyu as default for mix. 
class EasyDict(dict):
    """
    Dictionary whose items are also reachable as attributes.

    Every item is stored twice: as a dict entry and as an instance attribute.
    Nested dicts (also inside lists/tuples) are converted to the same class.

    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof

    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}
    >>> EasyDict((('a', 1), ('b', 2)))
    {'a': 1, 'b': 2}

    Set attributes

    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'

    Values extraction

    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> list(map(attrgetter('x'), d.bar))
    [1, 3]
    >>> list(map(attrgetter('y'), d.bar))
    [2, 4]
    >>> d = EasyDict()
    >>> list(d.keys())
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though

    >>> o = EasyDict({'clean':True})
    >>> list(o.items())
    [('clean', True)]

    And like a class

    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items

    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'

    update does not modify its argument and pop honours the default

    >>> src = {'x': 1}
    >>> d.update(src, y=2)
    >>> src
    {'x': 1}
    >>> d.pop('missing', 'fallback')
    'fallback'
    """

    # NOTE: do not add non-dunder helper methods to this class: __init__ copies
    # every non-dunder class attribute except 'update' and 'pop' into the dict.

    def __init__(self, d=None, **kwargs):
        """Build from an optional mapping / iterable of pairs plus keyword items."""
        d = {} if d is None else dict(d)
        if kwargs:
            d.update(kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes: expose values declared on (sub)classes as items too
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        """Store ``value`` both as attribute and as item, wrapping nested dicts."""
        if isinstance(value, (list, tuple)):
            # sequences (tuples included) are stored as lists with dict elements wrapped
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Merge ``e`` (mapping or iterable of pairs) and keyword items into self.

        The argument ``e`` is copied first, so the caller's object is never modified.
        """
        d = dict(e) if e else {}
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        """Remove ``k`` and return its value, or ``d`` when ``k`` is not present."""
        # drop the mirrored attribute if there is one; a missing key is not an error
        self.__dict__.pop(k, None)
        return super(EasyDict, self).pop(k, d)
def preprocess_image(
        p:StableDiffusionProcessingControl,
        pipe,
        input_image:Image.Image = None,
        init_image:Image.Image = None,
        input_mask:Image.Image = None,
        input_type:str = 0,
        unit_type:str = 'controlnet',
        active_process:list = [],
        active_model:list = [],
        selected_models:list = [],
        has_models:bool = False,
    ):
    """Prepare control inputs on `p` and select the diffusers task type.

    Steps, in order: optional "before" resize of the input image, masking,
    mask-size handling, running every processor in `active_process`, blending
    the processed images, filling `p.task_args` / `p.init_images` depending on
    `unit_type` and `input_type`, and finally switching `shared.sd_model`
    to the txt2img / img2img / inpaint variant.

    Args:
        p: processing object; mutated in place (sizes, init_images, task_args, ...).
        pipe: pipeline whose call signature (via `sd_models.get_call`) decides which task args may be set.
        input_image: control input image, may be None.
        init_image: separate init image, used when `input_type == 2`.
        input_mask: optional mask; when set, inpainting is selected.
        input_type: compared against 0 (control only), 1 (init image same as control), 2 (separate init image).
        unit_type: compared against 'reference', 'controlnet' and 'lite'.
        active_process: callables with `processor_id`, `override` and `model` attributes.
        active_model: only its length is used here.
        selected_models: list (its length is compared with the number of processed images) or other value.
        has_models: whether control models are active.

    Returns:
        tuple: `(processed_image, blended_image)`; `processed_image` may be a single image, a list or the raw input.
    """
    t0 = time.time()
    jobid = shared.state.begin('Preprocess')

    # run resize before
    if (p.resize_mode_before != 0) and (p.resize_name_before != 'None'):
        if (p.selected_scale_tab_before == 1) and (input_image is not None):
            p.width_before, p.height_before = int(input_image.width * p.scale_by_before), int(input_image.height * p.scale_by_before)
        if input_image is not None:
            debug_log(f'Control resize: op=before image={input_image} width={p.width_before} height={p.height_before} mode={p.resize_mode_before} name={p.resize_name_before} context="{p.resize_context_before}"')
            # getattr with default: values already present on p are kept, otherwise taken from the image before it is resized
            p.init_img_hash = getattr(p, 'init_img_hash', hashlib.sha256(input_image.tobytes()).hexdigest()[0:8]) # pylint: disable=attribute-defined-outside-init
            p.init_img_width = getattr(p, 'init_img_width', input_image.width) # pylint: disable=attribute-defined-outside-init
            p.init_img_height = getattr(p, 'init_img_height', input_image.height) # pylint: disable=attribute-defined-outside-init
            input_image = images.resize_image(p.resize_mode_before, input_image, p.width_before, p.height_before, p.resize_name_before, context=p.resize_context_before)
    # init image and override image are resized to the size of the control input
    if (input_image is not None) and (init_image is not None) and (init_image.size != input_image.size):
        debug_log(f'Control resize init: image={init_image} target={input_image}')
        init_image = images.resize_image(resize_mode=1, im=init_image, width=input_image.width, height=input_image.height)
    if (input_image is not None) and (p.override is not None) and (p.override.size != input_image.size):
        debug_log(f'Control resize override: image={p.override} target={input_image}')
        p.override = images.resize_image(resize_mode=1, im=p.override, width=input_image.width, height=input_image.height)
    if input_image is not None:
        p.width = input_image.width
        p.height = input_image.height
        debug_log(f'Control: input image={input_image}')

    # run masking
    if input_mask is not None:
        # record mask options in generation params; options at their "off" value are stored as None
        p.extra_generation_params["Mask only"] = masking.opts.mask_only if masking.opts.mask_only else None
        p.extra_generation_params["Mask auto"] = masking.opts.auto_mask if masking.opts.auto_mask != 'None' else None
        p.extra_generation_params["Mask invert"] = masking.opts.invert if masking.opts.invert else None
        p.extra_generation_params["Mask blur"] = masking.opts.mask_blur if masking.opts.mask_blur > 0 else None
        p.extra_generation_params["Mask erode"] = masking.opts.mask_erode if masking.opts.mask_erode > 0 else None
        p.extra_generation_params["Mask dilate"] = masking.opts.mask_dilate if masking.opts.mask_dilate > 0 else None
        p.extra_generation_params["Mask model"] = masking.opts.model if masking.opts.model is not None else None
        masked_image = masking.run_mask(input_image=input_image, input_mask=input_mask, return_type='Masked', invert=p.inpainting_mask_invert==1) if input_mask is not None else input_image
    else:
        masked_image = input_image

    # resize mask
    # NOTE(review): only the target size on p is updated here; no image is resized in this block
    if input_mask is not None and p.resize_mode_mask != 0 and p.resize_name_mask != 'None':
        if p.selected_scale_tab_mask == 1:
            p.width_mask, p.height_mask = int(input_image.width * p.scale_by_mask), int(input_image.height * p.scale_by_mask)
        p.width, p.height = p.width_mask, p.height_mask
        debug_log(f'Control resize: op=mask image={input_mask} width={p.width_mask} height={p.height_mask} mode={p.resize_mode_mask} name={p.resize_name_mask} context="{p.resize_context_mask}"')

    # run image processors
    processed_images = []
    blended_image = None
    for i, process in enumerate(active_process): # list[image]
        debug_log(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}')
        if p.resize_mode_before != 0:
            resize_mode = p.resize_mode_before
        else:
            resize_mode = 3 if shared.opts.control_aspect_ratio else 1
        processed_image = process(
            image_input=masked_image,
            width=p.width,
            height=p.height,
            mode='RGB',
            resize_mode=resize_mode,
            resize_name=p.resize_name_before,
            scale_tab=p.selected_scale_tab_before,
            scale_by=p.scale_by_before,
        )
        if processed_image is not None:
            processed_images.append(processed_image)
        if shared.opts.control_unload_processor and process.processor_id is not None:
            control_processors.config[process.processor_id]['dirty'] = True # to force reload
            process.model = None

    # blend processed images
    debug_log(f'Control processed: {len(processed_images)}')
    if len(processed_images) > 0:
        try:
            if len(p.extra_generation_params["Control process"]) == 0:
                p.extra_generation_params["Control process"] = None
            else:
                p.extra_generation_params["Control process"] = ';'.join([p.processor_id for p in active_process if p.processor_id is not None])
        except Exception:
            # best effort: e.g. the "Control process" key is missing or holds a value without len()
            pass
        if any(img is None for img in processed_images):
            shared.log.error('Control: one or more processed images are None')
            processed_images = [img for img in processed_images if img is not None]
        if len(processed_images) > 1 and len(active_process) != len(active_model):
            processed_image = [np.array(i) for i in processed_images]
            processed_image = util.blend(processed_image) # blend all processed images into one
            processed_image = Image.fromarray(processed_image)
            blended_image = processed_image
        elif len(processed_images) == 1:
            # single result: processed_image stays a one-element list, blended_image is the image itself
            processed_image = processed_images
            blended_image = processed_image[0]
        else:
            # processed_image keeps the value from the last processor call; only the blended preview is rebuilt
            blended_image = [np.array(i) for i in processed_images]
            blended_image = util.blend(blended_image) # blend all processed images into one
            blended_image = Image.fromarray(blended_image)
        if isinstance(selected_models, list) and len(processed_images) == len(selected_models) and len(processed_images) > 0:
            debug_log(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}')
            p.init_images = processed_images
        elif isinstance(selected_models, list) and len(processed_images) != len(selected_models):
            shared.log.error(f'Control: number of inputs does not match: input={len(processed_images)} models={len(selected_models)}')
        elif selected_models is not None:
            p.init_images = processed_image
    else:
        debug_log('Control processed: using input direct')
        processed_image = input_image

    # conditional assignment
    possible = sd_models.get_call(pipe).keys() # argument names accepted by the pipeline call
    if unit_type == 'reference' and has_models:
        p.ref_image = p.override or input_image
        p.task_args.pop('image', None)
        p.task_args['ref_image'] = p.ref_image
        debug_log(f'Control: process=None image={p.ref_image}')
        if p.ref_image is None:
            shared.log.error('Control: reference mode without image')
    elif unit_type == 'controlnet' and has_models:
        if input_type == 0: # Control only
            if 'control_image' in possible:
                p.task_args['control_image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images
            elif 'image' in possible:
                p.task_args['image'] = [p.init_images] if isinstance(p.init_images, Image.Image) else p.init_images
            if 'control_mode' in possible:
                p.task_args['control_mode'] = getattr(p, 'control_mode', None)
            if 'strength' in possible:
                p.task_args['strength'] = p.denoising_strength
            p.init_images = None
        elif input_type == 1: # Init image same as control
            p.init_images = [p.override or input_image] * max(1, len(active_model))
            if 'inpaint_image' in possible: # flex
                p.task_args['inpaint_image'] = p.init_images[0] if isinstance(p.init_images, list) else p.init_images
                # uniform mask whose gray level is derived from the denoising strength
                p.task_args['inpaint_mask'] = Image.new('L', p.task_args['inpaint_image'].size, int(p.denoising_strength * 255))
                p.task_args['control_image'] = p.init_images[0] if isinstance(p.init_images, list) else p.init_images
                p.task_args['width'] = p.width
                p.task_args['height'] = p.height
            elif 'control_image' in possible:
                p.task_args['control_image'] = p.init_images # switch image and control_image
            if 'control_mode' in possible:
                p.task_args['control_mode'] = getattr(p, 'control_mode', None)
            if 'strength' in possible:
                p.task_args['strength'] = p.denoising_strength
        elif input_type == 2: # Separate init image
            if init_image is None:
                shared.log.warning('Control: separate init image not provided')
                init_image = input_image
            if 'inpaint_image' in possible: # flex
                p.task_args['inpaint_image'] = p.init_images[0] if isinstance(p.init_images, list) else p.init_images
                p.task_args['inpaint_mask'] = Image.new('L', p.task_args['inpaint_image'].size, int(p.denoising_strength * 255))
                p.task_args['control_image'] = p.init_images[0] if isinstance(p.init_images, list) else p.init_images
                p.task_args['width'] = p.width
                p.task_args['height'] = p.height
            elif 'control_image' in possible:
                p.task_args['control_image'] = p.init_images # switch image and control_image
            if 'control_mode' in possible:
                p.task_args['control_mode'] = getattr(p, 'control_mode', None)
            if 'strength' in possible:
                p.task_args['strength'] = p.denoising_strength
            p.init_images = [init_image] * len(active_model)
    # NOTE(review): p.task_args is used as a dict throughout this function, so hasattr(p.task_args, 'control_image')
    # looks like it can never be true, which would make this branch dead code; presumably "'control_image' in p.task_args" was intended - confirm
    if hasattr(shared.sd_model, 'controlnet') and hasattr(p.task_args, 'control_image') and len(p.task_args['control_image']) > 1 and (shared.sd_model.__class__.__name__ == 'StableDiffusionXLControlNetUnionPipeline'): # special case for controlnet-union
        p.task_args['control_image'] = [[x] for x in p.task_args['control_image']]
        p.task_args['control_mode'] = [[x] for x in p.task_args['control_mode']]

    # determine txt2img, img2img, inpaint pipeline
    if unit_type == 'reference' and has_models: # special case
        p.is_control = True
        shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
    elif not has_models: # run in txt2img/img2img/inpaint mode
        if input_mask is not None:
            p.task_args['strength'] = p.denoising_strength
            p.image_mask = input_mask
            p.init_images = input_image if isinstance(input_image, list) else [input_image]
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING)
        elif processed_image is not None:
            p.init_images = processed_image if isinstance(processed_image, list) else [processed_image]
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
        else:
            p.init_hr(p.scale_by, p.resize_name, force=True)
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
    elif has_models: # actual control
        p.is_control = True
        if input_mask is not None:
            p.task_args['strength'] = p.denoising_strength
            p.image_mask = input_mask
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # only controlnet supports inpaint
        # not an elif: when init images are present this overrides the inpaint selection made just above
        if hasattr(p, 'init_images') and p.init_images is not None:
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE) # only controlnet supports img2img
        else:
            shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
            if hasattr(p, 'init_images') and p.init_images is not None and 'image' in possible:
                p.task_args['image'] = p.init_images # need to set explicitly for txt2img
                p.init_images = None
        if unit_type == 'lite':
            if input_type == 0:
                shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.TEXT_2_IMAGE)
                shared.sd_model.no_task_switch = True
            elif input_type == 1:
                p.init_images = [input_image]
            elif input_type == 2:
                if init_image is None:
                    shared.log.warning('Control: separate init image not provided')
                    init_image = input_image
                p.init_images = [init_image]

    t1 = time.time()
    process_timer.add('proc', t1-t0)
    shared.state.end(jobid)
    return processed_image, blended_image
================================================ import os import time import numpy as np from PIL import Image from installer import log from modules.errors import display from modules import devices, images models = {} cache_dir = 'models/control/processors' debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: CONTROL') config = { # placeholder 'None': {}, # pose models 'OpenPose': {'class': None, 'checkpoint': True, 'params': {'include_body': True, 'include_hand': False, 'include_face': False}}, 'DWPose': {'class': None, 'checkpoint': False, 'model': 'Tiny', 'params': {'min_confidence': 0.3}}, 'MediaPipe Face': {'class': None, 'checkpoint': False, 'params': {'max_faces': 1, 'min_confidence': 0.5}}, # outline models 'Canny': {'class': None, 'checkpoint': False, 'params': {'low_threshold': 100, 'high_threshold': 200}}, 'Edge': {'class': None, 'checkpoint': False, 'params': {'pf': True, 'mode': 'edge'}}, 'LineArt Realistic': {'class': None, 'checkpoint': True, 'params': {'coarse': False}}, 'LineArt Anime': {'class': None, 'checkpoint': True, 'params': {}}, 'HED': {'class': None, 'checkpoint': True, 'params': {'scribble': False, 'safe': False}}, 'PidiNet': {'class': None, 'checkpoint': True, 'params': {'scribble': False, 'safe': False, 'apply_filter': False}}, # depth models 'Midas Depth Hybrid': {'class': None, 'checkpoint': True, 'params': {'bg_th': 0.1, 'depth_and_normal': False}}, 'Leres Depth': {'class': None, 'checkpoint': True, 'params': {'boost': False, 'thr_a':0, 'thr_b':0}}, 'Zoe Depth': {'class': None, 'checkpoint': True, 'params': {'gamma_corrected': False}, 'load_config': {'pretrained_model_or_path': 'halffried/gyre_zoedepth', 'filename': 'ZoeD_M12_N.safetensors', 'model_type': "zoedepth"}}, 'Marigold Depth': {'class': None, 'checkpoint': True, 'params': {'denoising_steps': 10, 'ensemble_size': 10, 'processing_res': 512, 'match_input_res': True, 'color_map': 'None'}, 'load_config': 
def delay_load_config():
    """Import all detector classes and register them in the module-level `config`.

    The detector modules are imported here rather than at module import time;
    `Processor.load` calls this function the first time it finds a processor
    whose `class` entry is still None.

    Only the `class` key of each existing entry is filled in. The entries
    themselves are kept, so values changed through `update_settings` and
    pending `dirty` flags survive the lazy load (rebinding `config` to a fresh
    dict of defaults would silently discard them).
    """
    from modules.control.proc.hed import HEDdetector
    from modules.control.proc.canny import CannyDetector
    from modules.control.proc.edge import EdgeDetector
    from modules.control.proc.lineart import LineartDetector
    from modules.control.proc.lineart_anime import LineartAnimeDetector
    from modules.control.proc.pidi import PidiNetDetector
    from modules.control.proc.mediapipe_face import MediapipeFaceDetector
    from modules.control.proc.shuffle import ContentShuffleDetector
    from modules.control.proc.leres import LeresDetector
    from modules.control.proc.midas import MidasDetector
    from modules.control.proc.mlsd import MLSDdetector
    from modules.control.proc.normalbae import NormalBaeDetector
    from modules.control.proc.openpose import OpenposeDetector
    from modules.control.proc.dwpose import DWposeDetector
    from modules.control.proc.segment_anything import SamDetector
    from modules.control.proc.zoe import ZoeDetector
    from modules.control.proc.marigold import MarigoldDetector
    from modules.control.proc.dpt import DPTDetector
    from modules.control.proc.glpn import GLPNDetector
    from modules.control.proc.depth_anything import DepthAnythingDetector
    from modules.control.proc.depth_pro import DepthProDetector

    detector_classes = {
        # pose models
        'OpenPose': OpenposeDetector,
        'DWPose': DWposeDetector,
        'MediaPipe Face': MediapipeFaceDetector,
        # outline models
        'Canny': CannyDetector,
        'Edge': EdgeDetector,
        'LineArt Realistic': LineartDetector,
        'LineArt Anime': LineartAnimeDetector,
        'HED': HEDdetector,
        'PidiNet': PidiNetDetector,
        # depth models
        'Midas Depth Hybrid': MidasDetector,
        'Leres Depth': LeresDetector,
        'Zoe Depth': ZoeDetector,
        'Marigold Depth': MarigoldDetector,
        'Normal Bae': NormalBaeDetector,
        # segmentation models
        'SegmentAnything': SamDetector,
        # other models
        'MLSD': MLSDdetector,
        'Shuffle': ContentShuffleDetector,
        'DPT Depth Hybrid': DPTDetector,
        'GLPN Depth': GLPNDetector,
        'Depth Anything': DepthAnythingDetector,
        'Depth Pro': DepthProDetector,
    }
    for processor_id, detector_class in detector_classes.items():
        entry = config.get(processor_id)
        if entry is not None: # mutate in place so existing settings are preserved
            entry['class'] = detector_class
len(models) > 0: return models models = list(config) debug(f'Control list processors: path={cache_dir} models={models}') return models def update_settings(*settings): debug(f'Control settings: {settings}') def update(what, val): processor_id = what[0] if len(what) == 2 and config[processor_id][what[1]] != val: config[processor_id][what[1]] = val config[processor_id]['dirty'] = True log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') elif len(what) == 3 and config[processor_id][what[1]][what[2]] != val: config[processor_id][what[1]][what[2]] = val config[processor_id]['dirty'] = True log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') elif len(what) == 4 and config[processor_id][what[1]][what[2]][what[3]] != val: config[processor_id][what[1]][what[2]][what[3]] = val config[processor_id]['dirty'] = True log.debug(f'Control settings: id="{processor_id}" {what[-1]}={val}') update(['HED', 'params', 'scribble'], settings[0]) update(['Midas Depth Hybrid', 'params', 'bg_th'], settings[1]) update(['Midas Depth Hybrid', 'params', 'depth_and_normal'], settings[2]) update(['MLSD', 'params', 'thr_v'], settings[3]) update(['MLSD', 'params', 'thr_d'], settings[4]) update(['OpenPose', 'params', 'include_body'], settings[5]) update(['OpenPose', 'params', 'include_hand'], settings[6]) update(['OpenPose', 'params', 'include_face'], settings[7]) update(['PidiNet', 'params', 'scribble'], settings[8]) update(['PidiNet', 'params', 'apply_filter'], settings[9]) update(['LineArt Realistic', 'params', 'coarse'], settings[10]) update(['Leres Depth', 'params', 'boost'], settings[11]) update(['Leres Depth', 'params', 'thr_a'], settings[12]) update(['Leres Depth', 'params', 'thr_b'], settings[13]) update(['MediaPipe Face', 'params', 'max_faces'], settings[14]) update(['MediaPipe Face', 'params', 'min_confidence'], settings[15]) update(['Canny', 'params', 'low_threshold'], settings[16]) update(['Canny', 'params', 'high_threshold'], settings[17]) 
update(['DWPose', 'model'], settings[18]) update(['DWPose', 'params', 'min_confidence'], settings[19]) update(['SegmentAnything', 'model'], settings[20]) update(['Edge', 'params', 'pf'], settings[21]) update(['Edge', 'params', 'mode'], settings[22]) update(['Zoe Depth', 'params', 'gamma_corrected'], settings[23]) update(['Marigold Depth', 'params', 'color_map'], settings[24]) update(['Marigold Depth', 'params', 'denoising_steps'], settings[25]) update(['Marigold Depth', 'params', 'ensemble_size'], settings[26]) update(['Depth Anything', 'params', 'color_map'], settings[27]) update(['Depth Pro', 'params', 'color_map'], settings[28]) class Processor(): def __init__(self, processor_id: str = None, resize = True): self.model = None self.processor_id = None self.override = None self.resize = resize self.reset() self.config(processor_id) if processor_id is not None: self.load() def __str__(self): return f' Processor(id={self.processor_id} model={self.model.__class__.__name__})' if self.processor_id and self.model else '' def reset(self, processor_id: str = None): if self.model is not None: debug(f'Control Processor unloaded: id="{self.processor_id}"') self.model = None self.processor_id = processor_id devices.torch_gc(force=True, reason='processor') self.load_config = { 'cache_dir': cache_dir } from modules.shared import opts if opts.offline_mode: self.load_config["local_files_only"] = True os.environ['HF_HUB_OFFLINE'] = '1' else: os.environ.pop('HF_HUB_OFFLINE', None) os.unsetenv('HF_HUB_OFFLINE') def config(self, processor_id = None): if processor_id is not None: self.processor_id = processor_id from_config = config.get(self.processor_id, {}).get('load_config', None) """ if load_config is not None: for k, v in load_config.items(): self.load_config[k] = v """ if from_config is not None: for k, v in from_config.items(): self.load_config[k] = v def load(self, processor_id: str = None, force: bool = True) -> str: from modules.shared import state try: t0 = time.time() 
processor_id = processor_id or self.processor_id if processor_id is None or processor_id == 'None': self.reset() return '' if self.processor_id != processor_id: self.reset() self.config(processor_id) else: if not force and self.model is not None: # log.debug(f'Control Processor: id={processor_id} already loaded') return '' if processor_id not in config: log.error(f'Control Processor unknown: id="{processor_id}" available={list(config)}') return f'Processor failed to load: {processor_id}' cls = config[processor_id]['class'] if cls is None: delay_load_config() cls = config[processor_id]['class'] # log.debug(f'Control Processor loading: id="{processor_id}" class={cls.__name__}') debug(f'Control Processor config={self.load_config}') jobid = state.begin('Load processor') if 'DWPose' in processor_id: det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth' if 'Tiny' == config['DWPose']['model']: pose_config = 'config/rtmpose-t_8xb64-270e_coco-ubody-wholebody-256x192.py' pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-tt_ucoco.pth' elif 'Medium' == config['DWPose']['model']: pose_config = 'config/rtmpose-m_8xb64-270e_coco-ubody-wholebody-256x192.py' pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-mm_ucoco.pth' elif 'Large' == config['DWPose']['model']: pose_config = 'config/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' pose_ckpt = 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.pth' else: log.error(f'Control Processor load failed: id="{processor_id}" error=unknown model type') return f'Processor failed to load: {processor_id}' self.model = cls(det_ckpt=det_ckpt, pose_config=pose_config, pose_ckpt=pose_ckpt, device="cpu") elif 'SegmentAnything' in processor_id: if 'Base' == config['SegmentAnything']['model']: self.model = cls.from_pretrained(model_path = 'segments-arnaud/sam_vit_b', filename='sam_vit_b_01ec64.pth', 
model_type='vit_b', **self.load_config) elif 'Large' == config['SegmentAnything']['model']: self.model = cls.from_pretrained(model_path = 'segments-arnaud/sam_vit_l', filename='sam_vit_l_0b3195.pth', model_type='vit_l', **self.load_config) else: log.error(f'Control Processor load failed: id="{processor_id}" error=unknown model type') return f'Processor failed to load: {processor_id}' elif config[processor_id].get('load_config', None) is not None: self.model = cls.from_pretrained(**self.load_config) elif config[processor_id]['checkpoint']: self.model = cls.from_pretrained("lllyasviel/Annotators", **self.load_config) else: self.model = cls() # class instance only t1 = time.time() state.end(jobid) self.processor_id = processor_id log.debug(f'Control Processor loaded: id="{processor_id}" class={self.model.__class__.__name__} time={t1-t0:.2f}') return f'Processor loaded: {processor_id}' except Exception as e: log.error(f'Control Processor load failed: id="{processor_id}" error={e}') display(e, 'Control Processor load') return f'Processor load filed: {processor_id}' def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}): if self.override is not None: debug(f'Control Processor: id="{self.processor_id}" override={self.override}') width = image_input.width if image_input is not None else width height = image_input.height if image_input is not None else height if (width != self.override.width) or (height != self.override.height): debug(f'Control resize: op=override image={self.override} width={width} height={height} mode={resize_mode} name={resize_name}') image_input = images.resize_image(resize_mode, self.override, width, height, resize_name) else: image_input = self.override if resize_mode != 0 and resize_name != 'None': if scale_tab == 1: width_before, height_before = int(image_input.width * scale_by), 
def __call__(self, image_input: Image, mode: str = 'RGB', width: int = 0, height: int = 0, resize_mode: int = 0, resize_name: str = 'None', scale_tab: int = 1, scale_by: float = 1.0, local_config: dict = {}):
    """Run the loaded preprocessor model on an image and return the result.

    Args:
        image_input: control image; replaced by `self.override` when an override is set.
        mode: PIL mode the returned image is converted to when it is not 'RGB'.
        width, height: size used for the override image when `image_input` is None.
        resize_mode, resize_name: forwarded to `images.resize_image`; the pre-scale step
            is skipped when `resize_mode` is 0 or `resize_name` is 'None'.
        scale_tab, scale_by: the pre-scale step only runs when `scale_tab` is 1 and
            multiplies the input size by `scale_by`.
        local_config: per-call parameter overrides layered over the configured params.
            It is only read, never modified.

    Returns:
        The processed image. The unprocessed input is returned when no processor is
        selected, the processor is unknown, no model is loaded, there is no input,
        the model returns None, or the model call raises (the error is logged).
    """
    if self.override is not None:
        debug(f'Control Processor: id="{self.processor_id}" override={self.override}')
        width = image_input.width if image_input is not None else width
        height = image_input.height if image_input is not None else height
        if (width != self.override.width) or (height != self.override.height):
            debug(f'Control resize: op=override image={self.override} width={width} height={height} mode={resize_mode} name={resize_name}')
            image_input = images.resize_image(resize_mode, self.override, width, height, resize_name)
        else:
            image_input = self.override
    if resize_mode != 0 and resize_name != 'None':
        if scale_tab == 1:
            width_before, height_before = int(image_input.width * scale_by), int(image_input.height * scale_by)
            debug(f'Control resize: op=before image={image_input} width={width_before} height={height_before} mode={resize_mode} name={resize_name}')
            image_input = images.resize_image(resize_mode, image_input, width_before, height_before, resize_name)
    if self.processor_id is None or self.processor_id == 'None':
        return image_input
    image_process = image_input
    if image_input is None:
        # log.error('Control Processor: no input')
        return image_process
    if isinstance(image_input, list):
        image_input = image_input[0]
    if self.processor_id not in config:
        return image_process
    if config[self.processor_id].get('dirty', False):
        # settings changed since the model was loaded: clear the flag and reload
        processor_id = self.processor_id
        config[processor_id].pop('dirty')
        self.reset()
        self.load(processor_id)
    if self.model is None:
        # log.error('Control Processor: model not loaded')
        return image_process
    try:
        t0 = time.time()
        params = config.get(self.processor_id, {}).get('params', None)
        # Work on a copy: updating the shared dict in place would leak per-call overrides
        # into the global processor config. A missing/None entry becomes an empty
        # mapping so the `**kwargs` expansion below cannot fail on None.
        kwargs = dict(params) if params else {}
        if kwargs and local_config:
            kwargs.update(local_config)
        if self.resize:
            image_resized = image_input.resize((512, 512), Image.Resampling.LANCZOS)
        else:
            image_resized = image_input
        with devices.inference_context():
            image_process = self.model(image_resized, **kwargs)
        if image_process is None:
            log.error(f'Control Processor: id="{self.processor_id}" no image')
            return image_input
        if isinstance(image_process, np.ndarray):
            if np.max(image_process) < 2:
                # values below 2 are treated as a normalized map and scaled to 0..255
                image_process = (255.0 * image_process).astype(np.uint8)
            image_process = Image.fromarray(image_process, 'L')
        if self.resize and image_process.size != image_input.size:
            image_process = image_process.resize(image_input.size, Image.Resampling.LANCZOS)
        t1 = time.time()
        log.debug(f'Control Processor: id="{self.processor_id}" mode={mode} args={kwargs} time={t1-t0:.2f}')
    except Exception as e:
        log.error(f'Control Processor failed: id="{self.processor_id}" error={e}')
        display(e, 'Control Processor')
    if mode != 'RGB':
        image_process = image_process.convert(mode)
    return image_process


def preview(self):
    """Run the processor with default arguments on the current UI input image."""
    import modules.ui_control_helpers as helpers
    input_image = helpers.input_source
    if isinstance(input_image, list):
        input_image = input_image[0]
    debug('Control process preview')
    return self.__call__(input_image)
helpers input_image = helpers.input_source if isinstance(input_image, list): input_image = input_image[0] debug('Control process preview') return self.__call__(input_image) ================================================ FILE: modules/control/run.py ================================================ import os import sys from typing import List, Union import cv2 from PIL import Image from modules.control import util # helper functions from modules.control import unit # control units from modules.control import processors # image preprocessors from modules.control import tile # tiling module from modules.control.units import controlnet # lllyasviel ControlNet from modules.control.units import xs # VisLearn ControlNet-XS from modules.control.units import lite # Kohya ControlLLLite from modules.control.units import t2iadapter # TencentARC T2I-Adapter from modules.control.units import reference # ControlNet-Reference from modules.control.processor import preprocess_image from modules import devices, shared, errors, processing, images, sd_models, sd_vae, scripts_manager, masking from modules.processing_class import StableDiffusionProcessingControl from modules.ui_common import infotext_to_html from modules.api import script from modules.generation_parameters_copypaste import create_override_settings_dict from modules.paths import resolve_output_path debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None debug_log = shared.log.trace if debug else lambda *args, **kwargs: None pipe = None instance = None original_pipeline = None p_extra_args = {} unified_models = ['Flex2Pipeline'] # models that have controlnet builtin def restore_pipeline(): global pipe, instance # pylint: disable=global-statement if instance is not None and hasattr(instance, 'restore'): instance.restore() if (original_pipeline is not None) and (original_pipeline.__class__.__name__ != shared.sd_model.__class__.__name__): if debug: fn = 
def restore_pipeline():
    """Undo a control pipeline swap and drop the cached pipeline state.

    Calls `instance.restore()` when the active instance provides it, puts
    `original_pipeline` back into `shared.sd_model` when its class name differs
    from the current model, clears the module-level `pipe`/`instance` and runs
    `devices.torch_gc()`.
    """
    global pipe, instance # pylint: disable=global-statement
    if instance is not None and hasattr(instance, 'restore'):
        instance.restore()
    if original_pipeline is not None:
        current_name = shared.sd_model.__class__.__name__
        original_name = original_pipeline.__class__.__name__
        if original_name != current_name:
            if debug:
                # report the two nearest callers so unexpected restores can be traced
                fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
                shared.log.trace(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__} fn={fn}')
            shared.sd_model = original_pipeline
    pipe = None
    instance = None
    devices.torch_gc()


def terminate(msg):
    """Restore the original pipeline, log `msg` as an error and return it."""
    restore_pipeline()
    shared.log.error(f'Control terminated: {msg}')
    return msg


def is_unified_model():
    """Return True when the loaded model class is listed in `unified_models`."""
    model_class = shared.sd_model.__class__.__name__
    return model_class in unified_models


def has_inputs(inputs):
    """Return True when `inputs` holds at least one non-None item.

    A falsy value counts as no inputs; a non-list value is treated as a single item.
    """
    candidates = inputs or []
    if not isinstance(candidates, list):
        candidates = [candidates]
    return any(item is not None for item in candidates)
def set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, active_units, control_conditioning, control_guidance_start, control_guidance_end, inits=None, inputs=None):
    """Select and configure the pipeline used for this control run.

    Updates the module-level `pipe` and `instance`, records control metadata in
    `p.extra_generation_params` and fills the unit-type specific arguments on `p`.

    Args:
        p: processing object that receives ops, task args and generation params.
        has_models: whether any control model is active.
        unit_type: 't2i adapter', 'controlnet', 'xs', 'lite' or 'reference'; anything
            else (or no models) falls back to the plain `shared.sd_model`.
        selected_models: model or list of models handed to the pipeline wrapper.
        active_model: active units' model wrappers (used for the "Control model" info).
        active_strength: strengths; the first one becomes `p.strength` in fallback mode.
        active_units: active units, checked for override images when no inputs exist.
        control_conditioning, control_guidance_start, control_guidance_end: scalar or list.
        inits, inputs: init and control images.

    Returns:
        The pipeline to run, or None when models are active but no input image,
        init image or unit override is available.
    """
    global pipe, instance # pylint: disable=global-statement

    def as_list(value):
        # scalars are wrapped so single and multi unit runs are reported the same way
        return value if isinstance(value, list) else [value]

    def warn_separate_init(name):
        # only warn when an init image was actually supplied (an empty list is not one)
        if has_inputs(inits):
            shared.log.warning(f'Control: {name} does not support separate init image')

    pipe = None
    if has_models and not has_inputs(inits) and not has_inputs(inputs):
        if not any(has_inputs(u.override) for u in active_units if u.enabled): # check overrides
            shared.log.error('Control: no input images')
            return pipe
    if has_models:
        p.ops.append('control')
        p.extra_generation_params["Control type"] = unit_type # overriden later with pretty-print
        p.extra_generation_params["Control model"] = ';'.join([(m.model_id or '') for m in active_model if m.model is not None])
        p.extra_generation_params["Control conditioning"] = ';'.join([str(c) for c in as_list(control_conditioning)])
        p.extra_generation_params['Control start'] = ';'.join([str(c) for c in as_list(control_guidance_start)])
        p.extra_generation_params['Control end'] = ';'.join([str(c) for c in as_list(control_guidance_end)])
    if unit_type == 't2i adapter' and has_models:
        p.extra_generation_params["Control type"] = 'T2I-Adapter'
        p.task_args['adapter_conditioning_scale'] = control_conditioning
        instance = t2iadapter.AdapterPipeline(selected_models, shared.sd_model)
        pipe = instance.pipeline
        warn_separate_init('T2I-Adapter')
    elif unit_type == 'controlnet' and has_models:
        p.extra_generation_params["Control type"] = 'ControlNet'
        if shared.sd_model_type == 'f1':
            # the 'f1' model type always receives the conditioning scale as a list
            p.task_args['controlnet_conditioning_scale'] = as_list(control_conditioning)
        else:
            p.task_args['controlnet_conditioning_scale'] = control_conditioning
        p.task_args['control_guidance_start'] = control_guidance_start
        p.task_args['control_guidance_end'] = control_guidance_end
        p.task_args['guess_mode'] = p.guess_mode
        if not is_unified_model():
            instance = controlnet.ControlNetPipeline(selected_models, shared.sd_model, p=p)
            pipe = instance.pipeline
        else:
            # unified models have control built in, so the base pipeline is used as-is
            pipe = shared.sd_model
    elif unit_type == 'xs' and has_models:
        p.extra_generation_params["Control type"] = 'ControlNet-XS'
        p.controlnet_conditioning_scale = control_conditioning
        p.control_guidance_start = control_guidance_start
        p.control_guidance_end = control_guidance_end
        instance = xs.ControlNetXSPipeline(selected_models, shared.sd_model)
        pipe = instance.pipeline
        warn_separate_init('ControlNet-XS')
    elif unit_type == 'lite' and has_models:
        p.extra_generation_params["Control type"] = 'ControlLLLite'
        p.controlnet_conditioning_scale = control_conditioning
        instance = lite.ControlLLitePipeline(shared.sd_model)
        pipe = instance.pipeline
        warn_separate_init('ControlLLLite')
    elif unit_type == 'reference' and has_models:
        p.extra_generation_params["Control type"] = 'Reference'
        p.extra_generation_params["Control attention"] = p.attention
        p.task_args['reference_attn'] = 'Attention' in p.attention
        p.task_args['reference_adain'] = 'Adain' in p.attention
        p.task_args['attention_auto_machine_weight'] = p.query_weight
        p.task_args['gn_auto_machine_weight'] = p.adain_weight
        p.task_args['style_fidelity'] = p.fidelity
        instance = reference.ReferencePipeline(shared.sd_model)
        pipe = instance.pipeline
        # previously logged the ControlNet-XS message here (copy-paste)
        warn_separate_init('Reference')
    else: # run in txt2img/img2img mode
        if len(active_strength) > 0:
            p.strength = active_strength[0]
        pipe = shared.sd_model
        instance = None
    if (pipe is not None) and (pipe.__class__.__name__ != shared.sd_model.__class__.__name__):
        sd_models.copy_diffuser_options(pipe, shared.sd_model) # copy options from original pipeline
    debug_log(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}')
    return pipe
def check_active(p, unit_type, units):
    """Collect processors, models and per-unit settings for units of `unit_type`.

    Units of other types are ignored. Disabled units have their loaded controlnet
    model moved to `devices.cpu`; enabled ones have it moved to `devices.device`.
    Unit-specific settings (guess mode, control mode, tile, reference options,
    adapter factor) are written onto `p`.

    Returns:
        Tuple `(active_process, active_model, active_strength, active_start,
        active_end, active_units)`.
    """
    active_process: List[processors.Processor] = [] # all active preprocessors
    active_model: List[Union[controlnet.ControlNet, xs.ControlNetXS, t2iadapter.Adapter]] = [] # all active models
    active_strength: List[float] = [] # strength factors for all active models
    active_start: List[float] = [] # start step for all active models
    active_end: List[float] = [] # end step for all active models
    active_units: List[unit.Unit] = [] # all active units

    def register(u, model, with_range=False):
        # shared bookkeeping for model-backed units; append order matches the per-type branches
        active_process.append(u.process)
        active_model.append(model)
        active_strength.append(float(u.strength))
        if with_range:
            active_start.append(float(u.start))
            active_end.append(float(u.end))

    num_units = 0
    for u in units:
        if u.type != unit_type:
            continue
        num_units += 1
        debug_log(f'Control unit: i={num_units} type={u.type} enabled={u.enabled} cn={u.controlnet} proc={u.process}')
        cn_loaded = u.controlnet is not None and u.controlnet.model is not None
        if not u.enabled:
            if cn_loaded:
                debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}')
                sd_models.move_model(u.controlnet.model, devices.cpu)
            continue
        if cn_loaded:
            debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}')
            sd_models.move_model(u.controlnet.model, devices.device)
        if unit_type == 't2i adapter' and u.adapter.model is not None:
            register(u, u.adapter)
            p.adapter_conditioning_factor = u.factor
            active_units.append(u)
            shared.log.debug(f'Control T2I-Adapter unit: i={num_units} process="{u.process.processor_id}" model="{u.adapter.model_id}" strength={u.strength} factor={u.factor}')
        elif unit_type == 'controlnet' and (u.controlnet.model is not None or is_unified_model()):
            register(u, u.controlnet, with_range=True)
            p.guess_mode = u.guess
            active_units.append(u)
            if isinstance(u.mode, str):
                if not hasattr(p, 'control_mode'):
                    p.control_mode = []
                # unknown mode names fall back to index 0
                p.control_mode.append(u.choices.index(u.mode) if u.mode in u.choices else 0)
                p.is_tile = p.is_tile or 'tile' in u.mode.lower()
                p.control_tile = u.tile
                p.extra_generation_params["Control mode"] = u.mode
            shared.log.debug(f'Control unit: i={num_units} type=ControlNet process="{u.process.processor_id}" model="{u.controlnet.model_id}" strength={u.strength} guess={u.guess} start={u.start} end={u.end} mode={u.mode}')
        elif unit_type == 'xs' and u.controlnet.model is not None:
            register(u, u.controlnet, with_range=True)
            active_units.append(u)
            shared.log.debug(f'Control unit: i={num_units} type=ControlNetXS process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
        elif unit_type == 'lite' and u.controlnet.model is not None:
            register(u, u.controlnet)
            active_units.append(u)
            shared.log.debug(f'Control unit: i={num_units} type=ControlLLite process={u.process.processor_id} model={u.controlnet.model_id} strength={u.strength} guess={u.guess} start={u.start} end={u.end}')
        elif unit_type == 'reference':
            p.override = u.override
            p.attention = u.attention
            p.query_weight = float(u.query_weight)
            p.adain_weight = float(u.adain_weight)
            p.fidelity = u.fidelity
            active_units.append(u)
            shared.log.debug('Control Reference unit')
        else:
            # processor-only unit: no control model is attached
            if u.process.processor_id is not None:
                active_process.append(u.process)
                active_units.append(u)
                shared.log.debug(f'Control unit: i={num_units} type=Process process={u.process.processor_id}')
                # NOTE(review): the flattened source does not show whether this append sits
                # inside or after the `if`; confirm against the upstream file
                active_strength.append(float(u.strength))
    debug_log(f'Control active: process={len(active_process)} model={len(active_model)}')
    return active_process, active_model, active_strength, active_start, active_end, active_units
def check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end):
    """Reduce the active unit lists to the arguments handed to the pipeline.

    For model-backed unit types a single active model yields scalar settings
    (defaults: strength 1, start 0, end 1) while several models yield lists.
    For 'reference', models count as present when any reference unit is enabled.

    Returns:
        Tuple `(has_models, selected_models, control_conditioning,
        control_guidance_start, control_guidance_end)`.
    """
    has_models = False
    selected_models: List[Union[controlnet.ControlNetModel, xs.ControlNetXSModel, t2iadapter.AdapterModel]] = None
    control_conditioning = None
    control_guidance_start = None
    control_guidance_end = None

    def scalar_or_list(values):
        # a single value is passed through as a scalar, anything else as a fresh list
        return values[0] if len(values) == 1 else list(values)

    if unit_type in ('t2i adapter', 'controlnet', 'xs', 'lite'):
        count = len(active_model)
        if count == 1:
            only = active_model[0]
            selected_models = only.model
            p.is_tile = p.is_tile or 'tile' in (only.model_id or '').lower()
            has_models = (selected_models is not None) or is_unified_model()
            control_conditioning = active_strength[0] if len(active_strength) > 0 else 1 # strength or list[strength]
            control_guidance_start = active_start[0] if len(active_start) > 0 else 0
            control_guidance_end = active_end[0] if len(active_end) > 0 else 1
        elif count > 1:
            selected_models = [m.model for m in active_model if m.model is not None]
            has_models = len(selected_models) > 0
            control_conditioning = scalar_or_list(active_strength) # strength or list[strength]
            control_guidance_start = scalar_or_list(active_start)
            control_guidance_end = scalar_or_list(active_end)
    elif unit_type == 'reference':
        has_models = any(u.enabled for u in units if u.type == 'reference')
    return has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end


def control_set(kwargs):
    """Replace the module-level `p_extra_args` with `kwargs`; empty input is ignored."""
    global p_extra_args # pylint: disable=global-statement
    if not kwargs:
        return
    p_extra_args = {}
    debug_log(f'Control extra args: {kwargs}')
    p_extra_args.update(kwargs.items())


def init_units(units: List[unit.Unit]):
    """Load the processor and model named on every enabled unit.

    Loads use `force=False`. A unit-level override image is copied to the unit's
    processor when the processor has no override of its own.
    """
    def is_named(name):
        # '' and the literal 'None' both mean "nothing selected"
        return name is not None and name != '' and name != 'None'

    for u in units:
        if not u.enabled:
            continue
        if is_named(u.process_name):
            u.process.load(u.process_name, force=False)
        if is_named(u.model_name):
            if u.type == 't2i adapter':
                u.adapter.load(u.model_name, force=False)
            else:
                u.controlnet.load(u.model_name, force=False)
                # NOTE(review): nesting is not recoverable from the flattened source;
                # assumed to run for non-adapter units only - confirm upstream
                u.update_choices(u.model_name)
        if u.process is not None and u.process.override is None and u.override is not None:
            u.process.override = u.override
def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
                units: List[unit.Unit] = [], inputs: List[Image.Image] = [], inits: List[Image.Image] = [], mask: Image.Image = None, unit_type: str = None, is_generator: bool = True,
                input_type: int = 0,
                prompt: str = '', negative_prompt: str = '', styles: List[str] = [],
                steps: int = 20, sampler_index: int = None,
                seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1,
                guidance_name: str = 'Default', guidance_scale: float = 6.0, guidance_rescale: float = 0.0, guidance_start: float = 0.0, guidance_stop: float = 1.0,
                cfg_scale: float = 6.0, clip_skip: float = 1.0, image_cfg_scale: float = 6.0, diffusers_guidance_rescale: float = 0.7, pag_scale: float = 0.0, pag_adaptive: float = 0.5, cfg_end: float = 1.0,
                vae_type: str = 'Full', tiling: bool = False, hidiffusion: bool = False,
                detailer_enabled: bool = False, detailer_prompt: str = '', detailer_negative: str = '', detailer_steps: int = 10, detailer_strength: float = 0.3, detailer_resolution: int = 1024,
                hdr_mode: int = 0, hdr_brightness: float = 0, hdr_color: float = 0, hdr_sharpen: float = 0, hdr_clamp: bool = False, hdr_boundary: float = 4.0, hdr_threshold: float = 0.95,
                hdr_maximize: bool = False, hdr_max_center: float = 0.6, hdr_max_boundary: float = 1.0, hdr_color_picker: str = None, hdr_tint_ratio: float = 0,
                resize_mode_before: int = 0, resize_name_before: str = 'None', resize_context_before: str = 'None', width_before: int = 512, height_before: int = 512, scale_by_before: float = 1.0, selected_scale_tab_before: int = 0,
                resize_mode_after: int = 0, resize_name_after: str = 'None', resize_context_after: str = 'None', width_after: int = 0, height_after: int = 0, scale_by_after: float = 1.0, selected_scale_tab_after: int = 0,
                resize_mode_mask: int = 0, resize_name_mask: str = 'None', resize_context_mask: str = 'None', width_mask: int = 0, height_mask: int = 0, scale_by_mask: float = 1.0, selected_scale_tab_mask: int = 0,
                denoising_strength: float = 0.3, batch_count: int = 1, batch_size: int = 1,
                enable_hr: bool = False, hr_sampler_index: int = None, hr_denoising_strength: float = 0.0, hr_resize_mode: int = 0, hr_resize_context: str = 'None', hr_upscaler: str = None, hr_force: bool = False, hr_second_pass_steps: int = 20,
                hr_scale: float = 1.0, hr_resize_x: int = 0, hr_resize_y: int = 0, refiner_steps: int = 5, refiner_start: float = 0.0, refiner_prompt: str = '', refiner_negative: str = '',
                video_skip_frames: int = 0, video_type: str = 'None', video_duration: float = 2.0, video_loop: bool = False, video_pad: int = 0, video_interpolate: int = 0,
                extra: dict = {},
                override_script_name: str = None,
                override_script_args = [],
                *input_script_args,
        ):
    """Run a control generation job over images or a video file.

    This function is a generator (it contains `yield`), so the tuples given to
    `return` are delivered as the generator's `StopIteration` value. With
    `is_generator` set it additionally yields intermediate previews, termination
    messages and the final result.

    Args:
        state: free-form state string; enables hires when it contains 'refine'.
        units: control units; they are initialized and filtered by `unit_type`.
        inputs: list of control images/paths, or a path string to a video file.
        inits: optional separate init images (used when `input_type` is not 1).
        mask: optional inpaint mask; forces `input_type` 1 when `input_type` is 0.
        unit_type: control type name, normalized with strip/lower.
        input_type: 1 uses the control image as init image.
        extra: override settings as dict, or list converted via
            `create_override_settings_dict`.
        override_script_name, override_script_args: optional script selection.
        *input_script_args: script arguments forwarded to the script runner.
        Remaining keyword parameters are copied onto the processing object.

    Returns:
        `(output_images, blended_image, html_txt, output_filename)` on completion,
        or `([], '', '', message)` on early termination.
    """
    global pipe, original_pipeline, p_extra_args # pylint: disable=global-statement
    if 'refine' in state:
        enable_hr = True

    unit.current = units
    debug_log(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}')
    init_units(units)
    if inputs is None or (type(inputs) is list and len(inputs) == 0):
        inputs = [None] # a single empty slot so the processing loop still runs once
    output_images: List[Image.Image] = [] # output images
    processed_image: Image.Image = None # last processed image
    if mask is not None and input_type == 0:
        input_type = 1 # inpaint always requires control_image

    if sampler_index is None:
        shared.log.warning('Sampler: invalid')
        sampler_index = 0
    if hr_sampler_index is None:
        hr_sampler_index = sampler_index
    if isinstance(extra, list):
        extra = create_override_settings_dict(extra)

    p = StableDiffusionProcessingControl(
        prompt = prompt,
        negative_prompt = negative_prompt,
        styles = styles,
        steps = steps,
        n_iter = batch_count,
        batch_size = batch_size,
        sampler_name = processing.get_sampler_name(sampler_index),
        seed = seed,
        subseed = subseed,
        subseed_strength = subseed_strength,
        seed_resize_from_h = seed_resize_from_h,
        seed_resize_from_w = seed_resize_from_w,
        denoising_strength = denoising_strength,
        # modular guidance
        guidance_name = guidance_name,
        guidance_scale = guidance_scale,
        guidance_rescale = guidance_rescale,
        guidance_start = guidance_start,
        guidance_stop = guidance_stop,
        # legacy guidance
        cfg_scale = cfg_scale,
        cfg_end = cfg_end,
        clip_skip = clip_skip,
        image_cfg_scale = image_cfg_scale,
        diffusers_guidance_rescale = diffusers_guidance_rescale,
        pag_scale = pag_scale,
        pag_adaptive = pag_adaptive,
        # advanced
        vae_type = vae_type,
        tiling = tiling,
        hidiffusion = hidiffusion,
        # resize
        width = width_before,
        height = height_before,
        width_before = width_before,
        width_after = width_after,
        width_mask = width_mask,
        height_before = height_before,
        height_after = height_after,
        height_mask = height_mask,
        resize_name_before = resize_name_before,
        resize_name_after = resize_name_after,
        resize_name_mask = resize_name_mask,
        resize_mode_before = resize_mode_before if resize_name_before != 'None' and inputs is not None and len(inputs) > 0 else 0,
        resize_mode_after = resize_mode_after if resize_name_after != 'None' else 0,
        resize_mode_mask = resize_mode_mask if resize_name_mask != 'None' else 0,
        resize_context_before = resize_context_before,
        resize_context_after = resize_context_after,
        resize_context_mask = resize_context_mask,
        selected_scale_tab_before = selected_scale_tab_before,
        selected_scale_tab_after = selected_scale_tab_after,
        selected_scale_tab_mask = selected_scale_tab_mask,
        scale_by_before = scale_by_before,
        scale_by_after = scale_by_after,
        scale_by_mask = scale_by_mask,
        # hires
        enable_hr = enable_hr,
        hr_sampler_name = processing.get_sampler_name(hr_sampler_index),
        hr_denoising_strength = hr_denoising_strength,
        hr_resize_mode = hr_resize_mode if enable_hr else 0,
        hr_resize_context = hr_resize_context if enable_hr else 'None',
        hr_upscaler = hr_upscaler if enable_hr else None,
        hr_force = hr_force,
        hr_second_pass_steps = hr_second_pass_steps if enable_hr else 0,
        hr_scale = hr_scale if enable_hr else 1.0,
        hr_resize_x = hr_resize_x if enable_hr else 0,
        hr_resize_y = hr_resize_y if enable_hr else 0,
        # refiner
        refiner_steps = refiner_steps,
        refiner_start = refiner_start,
        refiner_prompt = refiner_prompt,
        refiner_negative = refiner_negative,
        # detailer
        detailer_enabled = detailer_enabled,
        detailer_prompt = detailer_prompt,
        detailer_negative = detailer_negative,
        detailer_steps = detailer_steps,
        detailer_strength = detailer_strength,
        detailer_resolution = detailer_resolution,
        # inpaint
        inpaint_full_res = masking.opts.mask_only,
        inpainting_mask_invert = 1 if masking.opts.invert else 0,
        # hdr
        hdr_mode=hdr_mode, hdr_brightness=hdr_brightness, hdr_color=hdr_color, hdr_sharpen=hdr_sharpen, hdr_clamp=hdr_clamp,
        hdr_boundary=hdr_boundary, hdr_threshold=hdr_threshold, hdr_maximize=hdr_maximize, hdr_max_center=hdr_max_center, hdr_max_boundary=hdr_max_boundary, hdr_color_picker=hdr_color_picker, hdr_tint_ratio=hdr_tint_ratio,
        # path
        outpath_samples=resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_control_samples),
        outpath_grids=resolve_output_path(shared.opts.outdir_grids, shared.opts.outdir_control_grids),
        # overrides
        override_settings=extra
    )
    p.state = state
    p.is_tile = False
    p.init_control = inits or []
    p.orig_init_images = inputs

    # TODO modernui: monkey-patch for missing tabs.select event
    if p.selected_scale_tab_before == 0 and p.resize_name_before != 'None' and p.scale_by_before != 1 and inputs is not None and len(inputs) > 0:
        shared.log.debug('Control: override resize mode=before')
        p.selected_scale_tab_before = 1
    if p.selected_scale_tab_after == 0 and p.resize_name_after != 'None' and p.scale_by_after != 1:
        shared.log.debug('Control: override resize mode=after')
        p.selected_scale_tab_after = 1
    if p.selected_scale_tab_mask == 0 and p.resize_name_mask != 'None' and p.scale_by_mask != 1:
        shared.log.debug('Control: override resize mode=mask')
        p.selected_scale_tab_mask = 1

    # hires/refine defined outside of main init
    vae_scale_factor = sd_vae.get_vae_scale_factor()
    if p.enable_hr and (p.hr_resize_x == 0 or p.hr_resize_y == 0):
        p.hr_upscale_to_x, p.hr_upscale_to_y = int(vae_scale_factor * int(p.width_before * p.hr_scale / vae_scale_factor)), int(vae_scale_factor * int(p.height_before * p.hr_scale / vae_scale_factor))
    elif p.enable_hr and (p.hr_upscale_to_x == 0 or p.hr_upscale_to_y == 0):
        # both axes snap to the vae scale factor (width used a hard-coded 8 before)
        p.hr_upscale_to_x, p.hr_upscale_to_y = int(vae_scale_factor * int(p.hr_resize_x / vae_scale_factor)), int(vae_scale_factor * int(p.hr_resize_y / vae_scale_factor))

    # apply and consume one-shot extra args registered via control_set
    for k, v in p_extra_args.items():
        setattr(p, k, v)
    p_extra_args = {}

    if shared.sd_model is None:
        shared.log.warning('Aborted: op=control model not loaded')
        return [], '', '', 'Error: model not loaded'

    unit_type = unit_type.strip().lower() if unit_type is not None else ''
    active_process, active_model, active_strength, active_start, active_end, active_units = check_active(p, unit_type, units)
    has_models, selected_models, control_conditioning, control_guidance_start, control_guidance_end = check_enabled(p, unit_type, units, active_model, active_strength, active_start, active_end)

    image_txt = ''
    info_txt = []

    p.is_tile = p.is_tile and has_models
    if is_unified_model():
        p.init_images = inputs

    pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, active_units, control_conditioning, control_guidance_start, control_guidance_end, inits, inputs)
    debug_log(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}')
    status = True
    frame = None
    video = None
    output_filename = None
    index = 0
    frames = 0
    blended_image = None

    # set pipeline
    if pipe is None:
        return [], '', '', 'Pipeline not set'
    elif pipe.__class__.__name__ != shared.sd_model.__class__.__name__:
        original_pipeline = shared.sd_model
        shared.sd_model = pipe
        sd_models.move_model(shared.sd_model, shared.device)
        debug_log(f'Control device={devices.device} dtype={devices.dtype}')
        sd_models.copy_diffuser_options(shared.sd_model, original_pipeline) # copy options from original pipeline
        sd_models.set_diffuser_options(shared.sd_model)
    else:
        original_pipeline = None

    try:
        with devices.inference_context():
            if isinstance(inputs, str) and os.path.exists(inputs): # only video, the rest is a list
                if input_type == 2: # separate init image
                    if isinstance(inits, str) and inits != inputs:
                        shared.log.warning('Control: separate init video not support for video input')
                        input_type = 1
                try:
                    video = cv2.VideoCapture(inputs)
                    if not video.isOpened():
                        if is_generator:
                            yield terminate(f'Video open failed: path={inputs}')
                        return [], '', '', 'Error: video open failed'
                    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    fps = int(video.get(cv2.CAP_PROP_FPS))
                    w, h = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    codec = util.decode_fourcc(video.get(cv2.CAP_PROP_FOURCC))
                    status, frame = video.read()
                    if status:
                        shared.state.frame_count = 1 + frames // (video_skip_frames + 1)
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    shared.log.debug(f'Control: input video: path={inputs} frames={frames} fps={fps} size={w}x{h} codec={codec}')
                except Exception as e:
                    if is_generator:
                        yield terminate(f'Video open failed: path={inputs} {e}')
                    return [], '', '', 'Error: video open failed'

            while status:
                processed_image = None
                if frame is not None:
                    inputs = [Image.fromarray(frame)] # cv2 to pil
                for i, input_image in enumerate(inputs):
                    if input_image is not None:
                        p.ops.append('img2img')
                    if pipe is None: # pipe may have been reset externally
                        # argument list must mirror the initial set_pipe call; active_units was missing here
                        pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, active_units, control_conditioning, control_guidance_start, control_guidance_end, inits, inputs)
                        debug_log(f'Control pipeline reinit: class={pipe.__class__.__name__}')
                        if pipe is None:
                            if is_generator:
                                yield terminate('Pipeline not set')
                            return [], '', '', 'Pipeline not set'
                    pipe.restore_pipeline = restore_pipeline
                    shared.sd_model.restore_pipeline = restore_pipeline
                    debug_log(f'Control Control image: {i + 1} of {len(inputs)}')
                    if shared.state.skipped:
                        shared.state.skipped = False
                        continue
                    if shared.state.interrupted:
                        shared.state.interrupted = False
                        if is_generator:
                            yield terminate('Interrupted')
                        return [], '', '', 'Interrupted'

                    # get input
                    if isinstance(input_image, str) and os.path.exists(input_image):
                        try:
                            input_image = Image.open(input_image)
                        except Exception as e:
                            shared.log.error(f'Control: image open failed: path={input_image} type=control error={e}')
                            continue

                    # match init input
                    if input_type == 1:
                        debug_log('Control Init image: same as control')
                        init_image = input_image
                    elif not inits: # None or empty: nothing to match against
                        debug_log('Control Init image: none')
                        init_image = None
                    else:
                        # wrap around so fewer init images than inputs are reused instead of raising IndexError
                        init_item = inits[i % len(inits)]
                        if isinstance(init_item, str):
                            debug_log(f'Control: init image: {init_item}')
                            try:
                                init_image = Image.open(init_item)
                            except Exception as e:
                                shared.log.error(f'Control: image open failed: path={init_item} type=init error={e}')
                                continue
                        else:
                            debug_log(f'Control Init image: {i % len(inits) + 1} of {len(inits)}')
                            init_image = init_item
                    if video is not None and index % (video_skip_frames + 1) != 0:
                        index += 1
                        continue
                    index += 1

                    processed_image, blended_image = preprocess_image(p, pipe, input_image, init_image, mask, input_type, unit_type, active_process, active_model, selected_models, has_models)
                    if is_generator:
                        yield (None, blended_image, '') # result is control_output, proces_output

                    # final check
                    if has_models:
                        if shared.sd_model.__class__.__name__ not in unified_models:
                            no_image = (
                                unit_type in ['controlnet', 't2i adapter', 'lite', 'xs']
                                and p.task_args.get('image', None) is None
                                and p.task_args.get('control_image', None) is None
                                and getattr(p, 'init_images', None) is None
                                and getattr(p, 'image', None) is None
                            )
                            if no_image:
                                if is_generator:
                                    shared.log.debug(f'Control args: {p.task_args}')
                                    yield terminate(f'Mode={p.extra_generation_params.get("Control type", None)} input image is none')
                                return [], '', '', 'Error: Input image is none'
                        if unit_type == 'lite':
                            instance.apply(selected_models, processed_image, control_conditioning)

                    # what are we doing?
                    if 'control' in p.ops:
                        p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_control_samples)
                    elif 'img2img' in p.ops:
                        p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_img2img_samples)
                    elif 'txt2img' in p.ops:
                        p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_txt2img_samples)
                    else: # fallback to txt2img
                        p.outpath_samples = resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_txt2img_samples)

                    # pipeline
                    output = None
                    script_run = False
                    if pipe is not None: # run new pipeline
                        debug_log(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}')
                        if sd_models.get_diffusers_task(pipe) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: # force vae back to gpu if not in txt2img mode
                            sd_models.move_model(pipe.vae, devices.device)

                        # init scripts
                        p.scripts = scripts_manager.scripts_control
                        # varargs arrive as a tuple; keep a list so override args can be assigned by index below
                        p.script_args = list(input_script_args) if input_script_args else []
                        if len(p.script_args) == 0:
                            if not p.scripts:
                                p.scripts.initialize_scripts(False)
                            p.script_args = script.init_default_script_args(p.scripts)

                        # init override scripts
                        if override_script_name and override_script_args and len(override_script_name) > 0:
                            selectable_scripts, selectable_script_idx = script.get_selectable_script(override_script_name, p.scripts)
                            if selectable_scripts:
                                for idx, override_arg in enumerate(override_script_args):
                                    p.script_args[selectable_scripts.args_from + idx] = override_arg
                                p.script_args[0] = selectable_script_idx + 1

                        # actual processing
                        processed: processing.Processed = None
                        if p.is_tile:
                            processed = tile.run_tiling(p, input_image)
                        if processed is None and p.scripts is not None:
                            processed = p.scripts.run(p, *p.script_args)
                        if processed is None:
                            processed = processing.process_images(p) # run actual pipeline
                        else:
                            script_run = True

                        # postprocessing
                        if p.scripts is not None:
                            processed = p.scripts.after(p, processed, *p.script_args)
                        output = None
                        if processed is not None and processed.images is not None:
                            output = processed.images
                            info_txt = [processed.infotext(p, i) for i in range(len(output))]

                        # output = pipe(**vars(p)).images # alternative direct pipe exec call
                    else: # blend all processed images and return
                        output = processed_image

                    # outputs
                    output = output or []
                    for output_image in output:
                        if output_image is not None:
                            output_images.append(output_image)
                            if shared.opts.include_mask and not script_run:
                                if processed_image is not None and isinstance(processed_image, Image.Image):
                                    output_images.append(processed_image)
                        # NOTE(review): nesting of this progress yield is not recoverable from the
                        # flattened source; placed at loop level because of its None fallback - confirm upstream
                        if is_generator and frame is not None and video is not None:
                            image_txt = f'{output_image.width}x{output_image.height}' if output_image is not None else 'None'
                            msg = f'Control output | {index} of {frames} skip {video_skip_frames} | Frame {image_txt}'
                            yield (output_image, blended_image, msg) # result is control_output, proces_output

                if video is not None and frame is not None:
                    status, frame = video.read()
                    if status:
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # guard the ratio: a reported frame count of 0 must not abort the run
                    debug_log(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/max(frames, 1):.2f}')
                else:
                    status = False

            debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} outputs={len(output_images)}')
    except Exception as e:
        shared.log.error(f'Control: type={unit_type} units={len(active_model)} {e}')
        errors.display(e, 'Control')
    finally:
        # release the capture on every exit path, including early returns and errors
        if video is not None:
            video.release()

    if len(output_images) == 0:
        output_images = None
        image_txt = '| Images None'
    else:
        image_txt = ''
        p.init_images = output_images # may be used for hires

    if video_type != 'None' and isinstance(output_images, list) and 'video' in p.ops:
        p.do_not_save_grid = True # pylint: disable=attribute-defined-outside-init
        output_filename = images.save_video(p, filename=None, images=output_images, video_type=video_type, duration=video_duration, loop=video_loop, pad=video_pad, interpolate=video_interpolate, sync=True)
        if shared.opts.gradio_skip_video:
            output_filename = ''
        image_txt = f'| Frames {len(output_images)} | Size {output_images[0].width}x{output_images[0].height}'

    p.close()
    restore_pipeline()
    debug_log(f'Ready: {image_txt}')

    # NOTE(review): the markup around this status text was stripped in the extracted source;
    # a paragraph wrapper is assumed - confirm against the upstream file
    html_txt = f'<p>Ready {image_txt}</p>' if image_txt != '' else ''
    if len(info_txt) > 0:
        html_txt = html_txt + infotext_to_html(info_txt[0])
    if is_generator:
        jobid = shared.state.begin('UI')
        yield (output_images, blended_image, html_txt, output_filename)
        shared.state.end(jobid)
    return (output_images, blended_image, html_txt, output_filename)
' if image_txt != '' else '' if len(info_txt) > 0: html_txt = html_txt + infotext_to_html(info_txt[0]) if is_generator: jobid = shared.state.begin('UI') yield (output_images, blended_image, html_txt, output_filename) shared.state.end(jobid) return (output_images, blended_image, html_txt, output_filename) ================================================ FILE: modules/control/test.py ================================================ import math from PIL import Image, ImageChops, ImageDraw from modules import shared, errors, images FONT_SIZE=48 def test_processors(image): from modules.control import processors if image is None: shared.log.error('Image not loaded') return None, None, None res = [] for processor_id in processors.list_models(): if shared.state.interrupted: continue shared.log.info(f'Testing processor: {processor_id}') processor = processors.Processor(processor_id) output = image if processor is None: shared.log.error(f'Processor load failed: id="{processor_id}"') processor_id = f'{processor_id} error' else: output = processor(image) if shared.opts.control_unload_processor: processor.reset() if output.size != image.size: output = output.resize(image.size, Image.Resampling.LANCZOS) if output.mode != image.mode: output = output.convert(image.mode) shared.log.debug(f'Testing processor: input={image} mode={image.mode} output={output} mode={output.mode}') diff = ImageChops.difference(image, output) if not diff.getbbox(): processor_id = f'{processor_id} null' draw = ImageDraw.Draw(output) font = images.get_font(FONT_SIZE) draw.text((10, 10), processor_id, (0,0,0), font=font) draw.text((8, 8), processor_id, (255,255,255), font=font) res.append(output) yield output, None, None, res rows = round(math.sqrt(len(res))) cols = math.ceil(len(res) / rows) w, h = 256, 256 size = (cols * w + cols, rows * h + rows) grid = Image.new('RGB', size=size, color='black') shared.log.info(f'Test processors: images={len(res)} grid={grid}') for i, image in enumerate(res): x = (i % 
def test_controlnets(prompt, negative, image):
    """Run every listed ControlNet model once against `image` and tile the results.

    This is a generator: after each model it yields `(preview, None, None, results)`
    and at the end `(None, grid, None, results)`, where `grid` is a sheet of
    256x256 thumbnails. The tuple layout is
    (preview_process, output_image, output_video, output_gallery).

    Args:
        prompt: positive prompt forwarded to the pipeline.
        negative: negative prompt forwarded to the pipeline.
        image: PIL image used as the control input; it is never modified.
    """
    from modules import devices, sd_models
    from modules.control.units import controlnet
    if image is None:
        shared.log.error('Image not loaded')
        return None, None, None
    res = []
    for model_id in controlnet.list_models():
        if model_id is None:
            model_id = 'None'
        if shared.state.interrupted:
            continue
        # draw the label on a copy so the caller's image (reused as pipeline input
        # for every following model) is never written to
        output = image.copy()
        if model_id != 'None':
            # the instance gets its own name: rebinding `controlnet` would shadow the
            # imported module and break `controlnet.ControlNetPipeline` below
            unit = controlnet.ControlNet(model_id=model_id, device=devices.device, dtype=devices.dtype)
            if unit is None or unit.model is None:
                shared.log.error(f'ControlNet load failed: id="{model_id}"')
                continue
            shared.log.info(f'Testing ControlNet: {model_id}')
            pipe = controlnet.ControlNetPipeline(controlnet=unit.model, pipeline=shared.sd_model)
            pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
            sd_models.set_diffuser_options(pipe)
            try:
                result = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
                output = result.images[0]
            except Exception as e:
                errors.display(e, f'ControlNet {model_id} inference')
                model_id = f'{model_id} error'
            pipe.restore()
        draw = ImageDraw.Draw(output)
        font = images.get_font(FONT_SIZE)
        draw.text((10, 10), model_id, (0,0,0), font=font) # dark shadow
        draw.text((8, 8), model_id, (255,255,255), font=font) # light label on top
        res.append(output)
        yield output, None, None, res
    if not res:
        # every model was skipped; the grid math below would divide by zero
        shared.log.warning('Test ControlNets: no images')
        return None, None, None, res
    rows = round(math.sqrt(len(res)))
    cols = math.ceil(len(res) / rows)
    w, h = 256, 256
    size = (cols * w + cols, rows * h + rows) # one extra pixel per cell as separator
    grid = Image.new('RGB', size=size, color='black')
    shared.log.info(f'Test ControlNets: images={len(res)} grid={grid}')
    for i, result_image in enumerate(res):
        x = (i % cols * w) + (i % cols)
        y = (i // cols * h) + (i // cols)
        thumb = result_image.copy().convert('RGB')
        thumb.thumbnail((w, h), Image.Resampling.HAMMING)
        grid.paste(thumb, box=(x, y))
    yield None, grid, None, res
    return None, grid, None, res # preview_process, output_image, output_video, output_gallery
def test_lite(prompt, negative, image):
    """Run every listed Control-LLite model once against `image` and tile the results.

    This is a generator: after each model it yields `(preview, None, None, results)`
    and at the end `(None, grid, None, results)`, where `grid` is a sheet of
    256x256 thumbnails. The tuple layout is
    (preview_process, output_image, output_video, output_gallery).

    Args:
        prompt: positive prompt forwarded to the pipeline.
        negative: negative prompt forwarded to the pipeline.
        image: PIL image used as the control input; it is never modified.
    """
    from modules import devices, sd_models
    from modules.control.units import lite
    if image is None:
        shared.log.error('Image not loaded')
        return None, None, None
    res = []
    for model_id in lite.list_models():
        if model_id is None:
            model_id = 'None'
        if shared.state.interrupted:
            continue
        # draw the label on a copy so the caller's image (reused as pipeline input
        # for every following model) is never written to
        output = image.copy()
        if model_id != 'None':
            # the instance gets its own name: rebinding `lite` would shadow the
            # imported module and break `lite.ControlLLitePipeline` below
            unit = lite.ControlLLLite(model_id=model_id, device=devices.device, dtype=devices.dtype)
            if unit is None:
                shared.log.error(f'Control-LLite load failed: id="{model_id}"')
                continue
            shared.log.info(f'Testing Control-LLite: {model_id}')
            pipe = lite.ControlLLitePipeline(pipeline=shared.sd_model)
            pipe.apply(controlnet=unit.model, image=image, conditioning=1.0)
            pipe.pipeline.to(device=devices.device, dtype=devices.dtype)
            sd_models.set_diffuser_options(pipe)
            try:
                result = pipe.pipeline(prompt=prompt, negative_prompt=negative, image=image, num_inference_steps=10, output_type='pil')
                output = result.images[0]
            except Exception as e:
                errors.display(e, f'Control-LLite {model_id} inference')
                model_id = f'{model_id} error'
            pipe.restore()
        draw = ImageDraw.Draw(output)
        font = images.get_font(FONT_SIZE)
        draw.text((10, 10), model_id, (0,0,0), font=font) # dark shadow
        draw.text((8, 8), model_id, (255,255,255), font=font) # light label on top
        res.append(output)
        yield output, None, None, res
    if not res:
        # every model was skipped; the grid math below would divide by zero
        shared.log.warning('Test Control-LLite: no images')
        return None, None, None, res
    rows = round(math.sqrt(len(res)))
    cols = math.ceil(len(res) / rows)
    w, h = 256, 256
    size = (cols * w + cols, rows * h + rows) # one extra pixel per cell as separator
    grid = Image.new('RGB', size=size, color='black')
    shared.log.info(f'Test Control-LLite: images={len(res)} grid={grid}')
    for i, result_image in enumerate(res):
        x = (i % cols * w) + (i % cols)
        y = (i // cols * h) + (i // cols)
        thumb = result_image.copy().convert('RGB')
        thumb.thumbnail((w, h), Image.Resampling.HAMMING)
        grid.paste(thumb, box=(x, y))
    yield None, grid, None, res
    return None, grid, None, res # preview_process, output_image, output_video, output_gallery
def get_tile(image: Image.Image, x: int, y: int, sx: int, sy: int) -> Image.Image:
    """Crop the cell at column `x`, row `y` out of an `sx` by `sy` grid laid over `image`.

    Cell edges are computed with integer division, so cells may differ by one
    pixel when the image size is not a multiple of the grid size.
    """
    full_w = image.width
    full_h = image.height
    left = x * full_w // sx
    upper = y * full_h // sy
    right = (x + 1) * full_w // sx
    lower = (y + 1) * full_h // sy
    return image.crop((left, upper, right, lower))
class Unit(): # mashup of gradio controls and mapping to actual implementation classes
    """One control unit: per-unit settings plus optional bindings to gradio components.

    Depending on `unit_type` the unit owns a t2i adapter, a controlnet-style model
    (controlnet / xs / lite) or nothing extra (reference / ip). A processor is always
    created. When gradio components are passed in, their change/click events are
    wired to update the unit's attributes.
    """

    @staticmethod
    def _ordered_range(start, end):
        """Return `(start, end)` sorted ascending; falsy values fall back to 0 and 1."""
        start = start or 0
        end = end or 1
        # compute both from the same pair: taking max() after start has already been
        # overwritten would collapse a reversed range to a single point
        return min(start, end), max(start, end)

    def update_choices(self, model_id=None):
        """Set `self.choices` to the control modes offered by the given (or current) model name."""
        name = model_id or self.model_name
        if name == 'InstantX Union F1':
            self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq']
        elif name == 'Shakker-Labs Union F1':
            self.choices = ['canny', 'tile', 'depth', 'blur', 'pose', 'gray', 'lq']
        elif name == 'Xinsir Union XL':
            self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal']
        elif name == 'Xinsir ProMax XL':
            self.choices = ['openpose', 'depth', 'scribble', 'canny', 'normal', 'segment', 'tile', 'repaint']
        else:
            self.choices = ['default']

    def __str__(self):
        return f'Unit(index={self.index} enabled={self.enabled} type="{self.type}" strength={self.strength} start={self.start} end={self.end}{self.process}{self.controlnet} override={self.override})'

    def __init__(self,
                 # values
                 index: int = None,
                 enabled: bool = None,
                 strength: float = None,
                 unit_type: str = None,
                 start: float = 0,
                 end: float = 1,
                 # ui bindings
                 enabled_cb = None,
                 reset_btn = None,
                 process_id = None,
                 preview_btn = None,
                 model_id = None,
                 model_strength = None,
                 preview_process = None,
                 image_upload = None,
                 image_reuse = None,
                 image_preview = None,
                 control_start = None,
                 control_end = None,
                 control_mode = None,
                 control_tile = None,
                 result_txt = None,
                 extra_controls: list = None,
                 ):
        """Create the unit, instantiate its model wrapper and bind any ui components.

        `model_id` and `process_id` may each be either a string (loaded immediately)
        or a gradio component (its change event triggers the load). All ui arguments
        are optional; bindings are only created for the components that are given.
        An unknown `unit_type` is logged and leaves the unit without any bindings.
        """
        # avoid a shared mutable default; None and [] both mean "no extra controls"
        extra_controls = [] if extra_controls is None else extra_controls
        self.model_id = model_id
        self.process_id = process_id
        self.controls = [gr.Label(value=unit_type, visible=False)] # separator
        self.index = index
        self.enabled = enabled or False
        self.type = unit_type
        self.strength = strength or 1.0
        self.model_strength = model_strength
        self.start, self.end = self._ordered_range(start, end)
        self.mode = None
        # processor always exists, adapter and controlnet are optional
        self.model_name = None
        self.process_name = None
        self.process: processors.Processor = processors.Processor()
        self.adapter: t2iadapter.Adapter = None
        self.controlnet: Union[controlnet.ControlNet, xs.ControlNetXS] = None
        # map to input image
        self.override: Image = None
        # global settings but passed per-unit
        self.factor = 1.0
        self.guess = False
        # reference settings
        self.attention = 'Attention'
        self.fidelity = 0.5
        self.query_weight = 1.0
        self.adain_weight = 1.0
        # control mode
        self.choices = ['default']
        # control tile
        self.tile = '1x1'

        def enabled_change(val):
            self.enabled = val

        def strength_change(val):
            self.strength = val

        def control_change(start, end):
            self.start = min(start, end)
            self.end = max(start, end)

        def control_mode_change(mode):
            # store the index of the selected mode; unknown or empty selection maps to 0
            self.mode = self.choices.index(mode) if mode is not None and mode in self.choices else 0

        def control_tile_change(tile):
            self.tile = tile

        def control_choices(model_id):
            self.update_choices(model_id)
            # the dropdown may deliver None; treat it as a name matching nothing
            name = (model_id or '').lower()
            mode_visible = 'union' in name or 'promax' in name
            tile_visible = 'union' in name or 'promax' in name or 'tile' in name
            return [gr.update(visible=mode_visible, choices=self.choices), gr.update(visible=tile_visible)]

        def adapter_extra(c1):
            self.factor = c1

        def controlnet_extra(c1):
            self.guess = c1

        def controlnetxs_extra(_c1):
            pass # gr.component passed directly to load method

        def reference_extra(c1, c2, c3, c4):
            self.attention = c1
            self.fidelity = c2
            self.query_weight = c3
            self.adain_weight = c4

        def upload_image(image_file):
            if image_file is None:
                self.process.override = None
                self.override = None
                log.debug('Control image: clear')
                return gr.update(value=None)
            try:
                self.process.override = Image.open(image_file.name)
                self.override = self.process.override
                log.debug(f'Control image: upload={self.process.override} path="{image_file.name}"')
                return gr.update(visible=self.process.override is not None, value=self.process.override)
            except Exception as e:
                log.error(f'Control image: upload path="{image_file.name}" error={e}')
                return gr.update(visible=False, value=None)

        def reuse_image(image):
            log.debug(f'Control process reuse image: {image}')
            self.process.override = image
            self.override = self.process.override
            return gr.update(visible=self.process.override is not None, value=self.process.override)

        def set_image(image):
            self.process.override = image
            self.override = image
            return gr.update(visible=image is not None)

        # actual init
        if self.type == 't2i adapter':
            self.adapter = t2iadapter.Adapter(device=default_device, dtype=default_dtype)
        elif self.type == 'controlnet':
            self.controlnet = controlnet.ControlNet(device=default_device, dtype=default_dtype)
        elif self.type == 'xs':
            self.controlnet = xs.ControlNetXS(device=default_device, dtype=default_dtype)
        elif self.type == 'lite':
            self.controlnet = lite.ControlLLLite(device=default_device, dtype=default_dtype)
        elif self.type == 'reference':
            pass
        elif self.type == 'ip':
            pass
        else:
            log.error(f'Control unknown type: unit={unit_type}')
            return

        # bind ui controls to properties if present
        if self.type == 't2i adapter':
            if model_id is not None:
                if isinstance(model_id, str):
                    self.adapter.load(model_id)
                else:
                    self.controls.append(model_id)
                    model_id.change(fn=self.adapter.load, inputs=[model_id], outputs=[result_txt], show_progress='full')
            if len(extra_controls) > 0:
                extra_controls[0].change(fn=adapter_extra, inputs=extra_controls)
        elif self.type == 'controlnet':
            if model_id is not None:
                if isinstance(model_id, str):
                    self.controlnet.load(model_id)
                else:
                    self.controls.append(model_id)
                    model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress='full')
                    model_id.change(fn=control_choices, inputs=[model_id], outputs=[control_mode, control_tile], show_progress='hidden')
            if len(extra_controls) > 0:
                extra_controls[0].change(fn=controlnet_extra, inputs=extra_controls)
        elif self.type == 'xs':
            if model_id is not None:
                if isinstance(model_id, str):
                    self.controlnet.load(model_id)
                else:
                    self.controls.append(model_id)
                    # the first extra control is forwarded straight to load(); slicing keeps
                    # construction from failing with IndexError when none was supplied
                    # NOTE(review): assumes load() accepts being called without it - confirm
                    model_id.change(fn=self.controlnet.load, inputs=[model_id, *extra_controls[:1]], outputs=[result_txt], show_progress='full')
            if len(extra_controls) > 0:
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'lite':
            if model_id is not None:
                if isinstance(model_id, str):
                    self.controlnet.load(model_id)
                else:
                    self.controls.append(model_id)
                    model_id.change(fn=self.controlnet.load, inputs=[model_id], outputs=[result_txt], show_progress='full')
            if len(extra_controls) > 0:
                extra_controls[0].change(fn=controlnetxs_extra, inputs=extra_controls)
        elif self.type == 'reference':
            if len(extra_controls) > 0:
                # any of the four reference controls changing re-reads all four values
                extra_controls[0].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[1].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[2].change(fn=reference_extra, inputs=extra_controls)
                extra_controls[3].change(fn=reference_extra, inputs=extra_controls)
        if enabled_cb is not None:
            self.controls.append(enabled_cb)
            enabled_cb.change(fn=enabled_change, inputs=[enabled_cb])
        if model_strength is not None:
            self.controls.append(model_strength)
            model_strength.change(fn=strength_change, inputs=[model_strength])
        if process_id is not None:
            if isinstance(process_id, str):
                self.process.load(process_id)
            else:
                self.controls.append(process_id)
                process_id.change(fn=self.process.load, inputs=[process_id], outputs=[result_txt], show_progress='full')
        if reset_btn is not None:
            reset_btn.click(fn=self.reset, inputs=[], outputs=[enabled_cb, model_id, process_id, model_strength])
        if preview_btn is not None:
            preview_btn.click(fn=self.process.preview, inputs=[], outputs=[preview_process]) # return list of images for gallery
        if image_upload is not None:
            image_upload.upload(fn=upload_image, inputs=[image_upload], outputs=[image_preview]) # return list of images for gallery
        if image_reuse is not None:
            image_reuse.click(fn=reuse_image, inputs=[preview_process], outputs=[image_preview]) # return list of images for gallery
        if image_preview is not None:
            self.controls.append(image_preview)
            image_preview.change(fn=set_image, inputs=[image_preview], outputs=[image_preview])
        if control_start is not None and control_end is not None:
            self.controls.append(control_start)
            self.controls.append(control_end)
            control_start.change(fn=control_change, inputs=[control_start, control_end])
            control_end.change(fn=control_change, inputs=[control_start, control_end])
        if control_mode is not None:
            self.controls.append(control_mode)
            control_mode.change(fn=control_mode_change, inputs=[control_mode])
        if control_tile is not None:
            self.controls.append(control_tile)
            control_tile.change(fn=control_tile_change, inputs=[control_tile])

    def reset(self):
        """Reset processor, adapter and controlnet (whichever exist) and return default ui values."""
        if self.process is not None:
            self.process.reset()
        if self.adapter is not None:
            self.adapter.reset()
        if self.controlnet is not None:
            self.controlnet.reset()
        self.override = None
        return [True, 'None', 'None', 1.0] # reset ui values
"lllyasviel/control_v11p_sd15_openpose", 'Scribble': "lllyasviel/control_v11p_sd15_scribble", 'Segment': "lllyasviel/control_v11p_sd15_seg", 'Shuffle': "lllyasviel/control_v11e_sd15_shuffle", 'SoftEdge': "lllyasviel/control_v11p_sd15_softedge", 'Tile': "lllyasviel/control_v11f1e_sd15_tile", 'Depth Anything': 'vladmandic/depth-anything', 'Canny FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_canny.safetensors', 'Inpaint FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_inpaint.safetensors', 'LineArt Anime FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_animeline.safetensors', 'LineArt FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_lineart.safetensors', 'MLSD FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_mlsd.safetensors', 'NormalBae FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_normal.safetensors', 'OpenPose FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_openpose.safetensors', 'Pix2Pix FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_pix2pix.safetensors', 'Scribble FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_scribble.safetensors', 'Segment FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_seg.safetensors', 'Shuffle FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_shuffle.safetensors', 'SoftEdge FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_softedge.safetensors', 'Tile FP16': 'Aptronym/SDNext/ControlNet11/controlnet11Models_tileE.safetensors', 'CiaraRowles TemporalNet': "CiaraRowles/TemporalNet", 'Ciaochaos Recolor': 'ioclab/control_v1p_sd15_brightness', 'Ciaochaos Illumination': 'ioclab/control_v1u_sd15_illumination/illumination20000.safetensors', } predefined_sdxl = { 'Canny Small XL': 'diffusers/controlnet-canny-sdxl-1.0-small', 'Canny Mid XL': 'diffusers/controlnet-canny-sdxl-1.0-mid', 'Canny XL': 'diffusers/controlnet-canny-sdxl-1.0', 'Depth Zoe XL': 'diffusers/controlnet-zoe-depth-sdxl-1.0', 'Depth Mid XL': 'diffusers/controlnet-depth-sdxl-1.0-mid', 'OpenPose 
XL': 'thibaud/controlnet-openpose-sdxl-1.0/bin', 'Xinsir Union XL': 'xinsir/controlnet-union-sdxl-1.0', 'Xinsir ProMax XL': 'brad-twinkl/controlnet-union-sdxl-1.0-promax', 'Xinsir OpenPose XL': 'xinsir/controlnet-openpose-sdxl-1.0', 'Xinsir Canny XL': 'xinsir/controlnet-canny-sdxl-1.0', 'Xinsir Depth XL': 'xinsir/controlnet-depth-sdxl-1.0', 'Xinsir Scribble XL': 'xinsir/controlnet-scribble-sdxl-1.0', 'Xinsir Anime Painter XL': 'xinsir/anime-painter', 'Xinsir Tile XL': 'xinsir/controlnet-tile-sdxl-1.0', 'NoobAI Canny XL': 'Eugeoter/noob-sdxl-controlnet-canny', 'NoobAI Lineart Anime XL': 'Eugeoter/noob-sdxl-controlnet-lineart_anime', 'NoobAI Depth XL': 'Eugeoter/noob-sdxl-controlnet-depth', 'NoobAI Normal XL': 'Eugeoter/noob-sdxl-controlnet-normal', 'NoobAI SoftEdge XL': 'Eugeoter/noob-sdxl-controlnet-softedge_hed', 'NoobAI OpenPose XL': 'einar77/noob-openpose', 'TTPlanet Tile Realistic XL': 'Yakonrus/SDXL_Controlnet_Tile_Realistic_v2', # 'StabilityAI Canny R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-canny-rank128.safetensors', # 'StabilityAI Depth R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-depth-rank128.safetensors', # 'StabilityAI Recolor R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-recolor-rank128.safetensors', # 'StabilityAI Sketch R128': 'stabilityai/control-lora/control-LoRAs-rank128/control-lora-sketch-rank128-metadata.safetensors', # 'StabilityAI Canny R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-canny-rank256.safetensors', # 'StabilityAI Depth R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-depth-rank256.safetensors', # 'StabilityAI Recolor R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors', # 'StabilityAI Sketch R256': 'stabilityai/control-lora/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors', } predefined_f1 = { "InstantX Union F1": 'InstantX/FLUX.1-dev-Controlnet-Union', "InstantX 
Canny F1": 'InstantX/FLUX.1-dev-Controlnet-Canny', "JasperAI Depth F1": 'jasperai/Flux.1-dev-Controlnet-Depth', "BlackForrestLabs Canny LoRA F1": '/huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors', "BlackForrestLabs Depth LoRA F1": '/huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors', "JasperAI Surface Normals F1": 'jasperai/Flux.1-dev-Controlnet-Surface-Normals', "JasperAI Upscaler F1": 'jasperai/Flux.1-dev-Controlnet-Upscaler', "Shakker-Labs Union F1": 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro', "Shakker-Labs Pose F1": 'Shakker-Labs/FLUX.1-dev-ControlNet-Pose', "Shakker-Labs Depth F1": 'Shakker-Labs/FLUX.1-dev-ControlNet-Depth', "XLabs-AI Canny F1": 'XLabs-AI/flux-controlnet-canny-diffusers', "XLabs-AI Depth F1": 'XLabs-AI/flux-controlnet-depth-diffusers', "XLabs-AI HED F1": 'XLabs-AI/flux-controlnet-hed-diffusers', "LibreFlux Segment F1": 'neuralvfx/LibreFlux-ControlNet', } predefined_sd3 = { "StabilityAI Canny SD35": 'diffusers-internal-dev/sd35-controlnet-canny-8b', "StabilityAI Depth SD35": 'diffusers-internal-dev/sd35-controlnet-depth-8b', "StabilityAI Blur SD35": 'diffusers-internal-dev/sd35-controlnet-blur-8b', "InstantX Canny SD35": 'InstantX/SD3-Controlnet-Canny', "InstantX Pose SD35": 'InstantX/SD3-Controlnet-Pose', "InstantX Depth SD35": 'InstantX/SD3-Controlnet-Depth', "InstantX Tile SD35": 'InstantX/SD3-Controlnet-Tile', "Alimama Inpainting SD35": 'alimama-creative/SD3-Controlnet-Inpainting', "Alimama SoftEdge SD35": 'alimama-creative/SD3-Controlnet-Softedge', } predefined_qwen = { "InstantX Union Qwen": 'InstantX/Qwen-Image-ControlNet-Union', } predefined_hunyuandit = { "HunyuanDiT Canny": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny', "HunyuanDiT Pose": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Pose', "HunyuanDiT Depth": 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Depth', } predefined_zimage = { "Z-Image-Turbo Union 1.0": 
def api_list_models(model_type: str = None):
    """Return the names of predefined ControlNet models for a model family plus local models.

    Args:
        model_type: one of 'sd', 'sdxl', 'f1', 'sd3', 'qwen', 'hunyuandit', 'zimage',
            or 'all' for every family. When empty, the currently loaded model type
            from `modules.shared.sd_model_type` is used.

    Returns:
        List of predefined model names for the selected family (or all of them),
        followed by the sorted names found by `find_models()`.
    """
    import modules.shared
    model_type = model_type or modules.shared.sd_model_type
    # one table instead of a chain of ifs, so every family is handled identically:
    # 'all' must include each of them, including zimage
    predefined_by_type = (
        ('sd', predefined_sd15),
        ('sdxl', predefined_sdxl),
        ('f1', predefined_f1),
        ('sd3', predefined_sd3),
        ('qwen', predefined_qwen),
        ('hunyuandit', predefined_hunyuandit),
        ('zimage', predefined_zimage),
    )
    model_list = []
    for type_name, predefined in predefined_by_type:
        if model_type in (type_name, 'all'):
            model_list += list(predefined)
    model_list += sorted(find_models())
    return model_list
model={self.model.__class__.__name__})' if self.model_id and self.model else '' def reset(self): if self.model is not None: debug_log(f'Control {what} model unloaded') self.model = None self.model_id = None devices.torch_gc(force=True, reason='controlnet') def get_class(self, model_id:str=''): from modules import shared if shared.sd_model_type == 'none': _load = shared.sd_model # trigger a load if shared.sd_model_type == 'sd': from diffusers import ControlNetModel as cls # pylint: disable=reimported config = 'lllyasviel/control_v11p_sd15_canny' elif shared.sd_model_type == 'sdxl': if 'union' in model_id.lower(): from diffusers import ControlNetUnionModel as cls config = 'xinsir/controlnet-union-sdxl-1.0' elif 'promax' in model_id.lower(): from diffusers import ControlNetUnionModel as cls config = 'brad-twinkl/controlnet-union-sdxl-1.0-promax' else: from diffusers import ControlNetModel as cls # pylint: disable=reimported # sdxl shares same model class config = 'Eugeoter/noob-sdxl-controlnet-canny' elif shared.sd_model_type == 'f1': from diffusers import FluxControlNetModel as cls config = 'InstantX/FLUX.1-dev-Controlnet-Union' elif shared.sd_model_type == 'sd3': from diffusers import SD3ControlNetModel as cls config = 'InstantX/SD3-Controlnet-Canny' elif shared.sd_model_type == 'qwen': from diffusers import QwenImageControlNetModel as cls config = 'InstantX/Qwen-Image-ControlNet-Union' elif shared.sd_model_type == 'hunyuandit': from diffusers import HunyuanDiT2DControlNetModel as cls config = 'Tencent-Hunyuan/HunyuanDiT-v1.2-ControlNet-Diffusers-Canny' elif shared.sd_model_type == 'zimage': from diffusers import ZImageControlNetModel as cls if '2.0' in model_id: config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0' elif '2.1' in model_id: config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1' else: config = 'hlky/Z-Image-Turbo-Fun-Controlnet-Union' else: log.error(f'Control {what}: type={shared.sd_model_type} unsupported model') return None, None return cls, 
def load(self, model_id: str = None, force: bool = False) -> str:
    """Load a ControlNet model by its display name and prepare it for use.

    Args:
        model_id: key into `all_models`; falls back to `self.model_id`.
            `None` or the string 'None' unloads the current model.
        force: reload even when `model_id` matches the already loaded id.

    Returns:
        A status string after a completed load or after an unexpected
        exception; `None` on every early exit (unload, unknown id, LoRA entry,
        already loaded, unsupported base model, or failed `from_pretrained`).
    """
    with load_lock:
        jobid = None # set once state.begin() has been called so that finally can close it
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return
            if model_id not in all_models:
                log.error(f'Control {what}: id="{model_id}" available={list(all_models)} unknown model')
                return
            model_path = all_models[model_id]
            if model_path == '':
                return
            if model_path is None:
                log.error(f'Control {what} model load: id="{model_id}" unknown model id')
                return
            if 'lora' in model_id.lower():
                # lora entries are not loaded here: the path string itself is stored as the "model"
                self.model = model_path
                return
            if model_id == self.model_id and not force:
                # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded')
                return
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            cls, config = self.get_class(model_id)
            if cls is None:
                log.error(f'Control {what} model load: id="{model_id}" unknown base model')
                return
            self.reset()
            jobid = state.begin(f'Load {what}')
            if model_path.endswith('.safetensors'):
                self.load_safetensors(model_id, model_path, cls, config)
            else:
                kwargs = {}
                if '/bin' in model_path:
                    # a '/bin' marker in the path selects non-safetensors weights and is stripped from the repo id
                    model_path = model_path.replace('/bin', '')
                    self.load_config['use_safetensors'] = False
                else:
                    self.load_config['use_safetensors'] = True
                if variants.get(model_id, None) is not None:
                    kwargs['variant'] = variants[model_id]
                if subfolders.get(model_id, None) is not None:
                    kwargs['subfolder'] = subfolders[model_id]
                if remote_code.get(model_id, None) is not None:
                    kwargs['trust_remote_code'] = remote_code[model_id]
                try:
                    self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs)
                except Exception as e:
                    log.error(f'Control {what} model load: id="{model_id}" {e}')
                    if debug:
                        errors.display(e, 'Control')
            if self.model is None:
                return # job is closed in finally
            if not cmd_opts.lowvram: # lowvram will cause unet<->controlnet to ping-pong but saves more memory
                self.model.offload_never = True
            if self.dtype is not None:
                self.model.to(self.dtype)
            if self.device is not None:
                if (opts.diffusers_offload_mode != 'balanced') and hasattr(self.model, 'to'):
                    try:
                        self.model.to(self.device)
                    except Exception as e:
                        if 'Cannot copy out of meta tensor' in str(e):
                            self.model.to_empty(device=self.device)
            if "Control" in opts.sdnq_quantize_weights:
                try:
                    log.debug(f'Control {what} model SDNQ quantize: id="{model_id}"')
                    from modules.model_quant import sdnq_quantize_model
                    self.model = sdnq_quantize_model(self.model)
                except Exception as e:
                    log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}')
            elif "Control" in opts.optimum_quanto_weights:
                try:
                    log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
                    model_quant.load_quanto('Load model: type=Control')
                    from modules.model_quant import optimum_quanto_model
                    self.model = optimum_quanto_model(self.model)
                except Exception as e:
                    log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}')
            elif "Control" in opts.torchao_quantization:
                try:
                    log.debug(f'Control {what} model Torch AO: id="{model_id}"')
                    model_quant.load_torchao('Load model: type=Control')
                    from modules.model_quant import torchao_quantization
                    self.model = torchao_quantization(self.model)
                except Exception as e:
                    log.error(f'Control {what} model Torch AO: id="{model_id}" {e}')
            if self.device is not None:
                sd_models.move_model(self.model, self.device)
            if "Control" in opts.cuda_compile:
                try:
                    from modules.sd_models_compile import compile_torch
                    self.model = compile_torch(self.model, apply_to_components=False, op="Control")
                except Exception as e:
                    log.warning(f"Control compile error: {e}")
            t1 = time.time()
            self.model_id = model_id
            log.info(f'Control {what} model loaded: id="{self.model_id}" path="{model_path}" cls={cls.__name__} time={t1-t0:.2f}')
            return f'{what} loaded model: {self.model_id}'
        except Exception as e:
            log.error(f'Control {what} model load: id="{model_id}" {e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
        finally:
            # previously state.end() was only reached on full success, so a failed
            # from_pretrained or any exception left the job opened by state.begin() unclosed
            if jobid is not None:
                state.end(jobid)
log.warning(f'Control {what}: units={classes} mixed type is not supported') return if isinstance(controlnets, list) and len(controlnets) == 1: controlnets = controlnets[0] cls = StableDiffusionXLControlNetUnionPipeline else: cls = StableDiffusionXLControlNetPipeline self.pipeline = cls( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), image_encoder=getattr(pipeline, 'image_encoder', None), controlnet=controlnets, # can be a list ) elif detect.is_f1(pipeline) and len(controlnets) > 0: from diffusers import FluxControlNetPipeline self.pipeline = FluxControlNetPipeline( vae=pipeline.vae.to(devices.device), text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, transformer=pipeline.transformer, scheduler=pipeline.scheduler, controlnet=controlnets, # can be a list ) elif detect.is_sd3(pipeline) and len(controlnets) > 0: from diffusers import StableDiffusion3ControlNetPipeline self.pipeline = StableDiffusion3ControlNetPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, text_encoder_3=pipeline.text_encoder_3, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, tokenizer_3=pipeline.tokenizer_3, transformer=pipeline.transformer, scheduler=pipeline.scheduler, controlnet=controlnets, # can be a list ) elif detect.is_sd15(pipeline) and len(controlnets) > 0: from diffusers import StableDiffusionControlNetPipeline self.pipeline = StableDiffusionControlNetPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), image_encoder=getattr(pipeline, 'image_encoder', None), 
requires_safety_checker=False, safety_checker=None, controlnet=controlnets, # can be a list ) sd_models.move_model(self.pipeline, pipeline.device) elif detect.is_qwen(pipeline) and len(controlnets) > 0: from diffusers import QwenImageControlNetPipeline self.pipeline = QwenImageControlNetPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, transformer=pipeline.transformer, scheduler=pipeline.scheduler, controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list ) elif detect.is_hunyuandit(pipeline) and len(controlnets) > 0: from diffusers import HunyuanDiTControlNetPipeline self.pipeline = HunyuanDiTControlNetPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, text_encoder_2=pipeline.text_encoder_2, tokenizer_2=pipeline.tokenizer_2, transformer=pipeline.transformer, scheduler=pipeline.scheduler, safety_checker=None, feature_extractor=None, controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list ) elif detect.is_zimage(pipeline) and len(controlnets) > 0: from diffusers import ZImageControlNetPipeline self.pipeline = ZImageControlNetPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, transformer=pipeline.transformer, scheduler=pipeline.scheduler, controlnet=controlnets[0] if isinstance(controlnets, list) else controlnets, # can be a list ) self.pipeline.task_args = { 'guidance_scale': 1 } elif len(loras) > 0: self.pipeline = pipeline for lora in loras: log.debug(f'Control {what} pipeline: lora="{lora}"') lora = lora.replace('/huggingface.co/', '') self.pipeline.load_lora_weights(lora) """ if p is not None: p.prompt += f'' """ else: log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') return if self.pipeline is None: log.error(f'Control {what} pipeline: not initialized') return if dtype is not None: self.pipeline = 
def is_compatible(model, pattern='None'):
    """Return True when the class name of `model` starts with `pattern`; False for `None`."""
    if model is None:
        return False
    model_class = getattr(model, '__class__', None)
    if model_class is None:
        return False
    return model_class.__name__.startswith(pattern)


def is_sd15(model):
    """Match class names starting with 'StableDiffusion' (a prefix shared by the XL and SD3 names)."""
    return is_compatible(model, pattern='StableDiffusion')


def is_sdxl(model):
    """Match class names starting with 'StableDiffusionXL'."""
    return is_compatible(model, pattern='StableDiffusionXL')


def is_f1(model):
    """Match class names starting with 'Flux'."""
    return is_compatible(model, pattern='Flux')


def is_sd3(model):
    """Match class names starting with 'StableDiffusion3Pipeline'."""
    return is_compatible(model, pattern='StableDiffusion3Pipeline')


def is_qwen(model):
    """Match class names starting with 'Qwen'."""
    return is_compatible(model, pattern='Qwen')


def is_hunyuandit(model):
    """Match class names starting with 'HunyuanDiT'."""
    return is_compatible(model, pattern='HunyuanDiT')


def is_zimage(model):
    """Match class names starting with 'ZImage'."""
    return is_compatible(model, pattern='ZImage')
def find_models():
    """Collect `.safetensors` entries listed under `<control_dir>/lite`.

    Returns a dict mapping each entry's extension-less path relative to that
    folder to `os.path.join(folder, entry)`, and merges the same mapping into
    the module-level `all_models`.
    """
    root = os.path.join(opts.control_dir, 'lite')
    found = {
        os.path.splitext(os.path.relpath(entry, root))[0]: os.path.join(root, entry)
        for entry in listdir(root)
        if entry.endswith('.safetensors')
    }
    all_models.update(found)
    return found
def reset(self):
    """Forget the current model and its id, emitting a debug trace if a model was loaded."""
    had_model = self.model is not None
    if had_model:
        debug(f'Control {what} model unloaded')
    self.model, self.model_id = None, None
class ControlLLitePipeline():
    """Applies one or more ControlNetLLLite models to an existing pipeline."""

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
        self.pipeline = pipeline # pipeline handed to each net's apply()
        # self.pipeline.__class__.__name__ = 'ControlLLLitePipeline'
        self.nets = [] # nets used by the most recent apply()

    def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
        """Apply each net to the pipeline with its conditioning image and weight.

        Args:
            controlnet: a single ControlNetLLLite or a list of them.
            image: a PIL image or a list of PIL images; nothing happens when `None`.
            conditioning: a single numeric weight or a list of weights.

        Images and weights are reused cyclically when there are fewer of them than nets.
        """
        if image is None:
            return
        self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
        debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
        # accept int as well as float: a bare int previously reached len(weight) and raised TypeError
        weight = [conditioning] if isinstance(conditioning, (int, float)) else conditioning
        images = [image] if isinstance(image, Image.Image) else image
        images = [i.convert('RGB') for i in images]
        for i, cn in enumerate(self.nets):
            cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])

    def restore(self):
        """Undo all applied nets via `clear_all_lllite()` and clear the net list."""
        from modules.control.units.lite_model import clear_all_lllite
        clear_all_lllite()
        self.nets = []
def forward(self, x):
    """Compute this module's output for `x`, conditioned on the stored image.

    The conditioning embedding is computed from `self.cond_image` on first use
    and cached in `self.cond_emb`; later calls reuse the cached value.
    """
    cond = self.cond_emb
    if cond is None:
        cond = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
        if not self.is_conv2d:
            # flatten spatial dims for the linear path: b,c,h,w -> b,h*w,c
            batch, channels, height, width = cond.shape
            cond = cond.view(batch, channels, height * width).permute(0, 2, 1)
        self.cond_emb = cond

    # with uncond/cond batching, x may have a multiple of the embedding's batch size
    if x.shape[0] != cond.shape[0]:
        repeats = x.shape[0] // cond.shape[0]
        trailing = (1, 1, 1) if self.is_conv2d else (1, 1)
        cond = cond.repeat(repeats, *trailing)

    concat_dim = 1 if self.is_conv2d else 2
    hidden = torch.cat([cond, self.down(x)], dim=concat_dim)
    return self.up(self.mid(hidden))
def get_hacked_forward(self, original_forward, model, blk):
    """Build a replacement forward for `blk`.

    The returned callable adds the weighted output of every `(weight, module)`
    pair in `blk.lllite_list` to its input and then delegates to
    `original_forward`. The list is read on each call, so entries appended
    later are picked up. `model` is accepted but not used here.
    """
    @torch.no_grad()
    def forward(x, **kwargs):
        delta = 0
        for scale, lllite in blk.lllite_list:
            # match the module to the incoming activation before evaluating it
            lllite.to(x.device)
            lllite.to(x.dtype)
            delta = delta + lllite(x) * scale
        return original_forward(x + delta, **kwargs)

    return forward
'unfuse_qkv_projections'): pipeline.unfuse_qkv_projections() if detect.is_sdxl(pipeline): cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_xl_reference', module_file='pipeline.py') self.pipeline = cls( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), ) sd_models.move_model(self.pipeline, pipeline.device) elif detect.is_sd15(pipeline): cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_reference', module_file='pipeline.py') self.pipeline = cls( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), requires_safety_checker=False, safety_checker=None, ) sd_models.move_model(self.pipeline, pipeline.device) else: log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') return if dtype is not None and self.pipeline is not None: self.pipeline = self.pipeline.to(dtype) t1 = time.time() if self.pipeline is not None: log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') else: log.error(f'Control {what} pipeline: not initialized') def restore(self): self.pipeline = None return self.orig_pipeline ================================================ FILE: modules/control/units/t2iadapter.py ================================================ import os import time from typing import Union import threading from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, MultiAdapter, StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline # pylint: disable=unused-import from installer import log from modules import errors, sd_models from modules.control.units import detect 
what = 'T2I-Adapter' debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: CONTROL') predefined_sd15 = { 'Segment': ('TencentARC/t2iadapter_seg_sd14v1', {}), 'Zoe Depth': ('TencentARC/t2iadapter_zoedepth_sd15v1', {}), 'OpenPose': ('TencentARC/t2iadapter_openpose_sd14v1', {}), 'KeyPose': ('TencentARC/t2iadapter_keypose_sd14v1', {}), 'Color': ('TencentARC/t2iadapter_color_sd14v1', {}), 'Depth v1': ('TencentARC/t2iadapter_depth_sd14v1', {}), 'Depth v2': ('TencentARC/t2iadapter_depth_sd15v2', {}), 'Canny v1': ('TencentARC/t2iadapter_canny_sd14v1', {}), 'Canny v2': ('TencentARC/t2iadapter_canny_sd15v2', {}), 'Sketch v1': ('TencentARC/t2iadapter_sketch_sd14v1', {}), 'Sketch v2': ('TencentARC/t2iadapter_sketch_sd15v2', {}), # 'Coadapter Canny': 'TencentARC/T2I-Adapter/models/coadapter-canny-sd15v1.pth', # 'Coadapter Color': 'TencentARC/T2I-Adapter/models/coadapter-color-sd15v1.pth', # 'Coadapter Depth': 'TencentARC/T2I-Adapter/models/coadapter-depth-sd15v1.pth', # 'Coadapter Fuser': 'TencentARC/T2I-Adapter/models/coadapter-fuser-sd15v1.pth', # 'Coadapter Sketch': 'TencentARC/T2I-Adapter/models/coadapter-sketch-sd15v1.pth', # 'Coadapter Style': 'TencentARC/T2I-Adapter/models/coadapter-style-sd15v1.pth', } predefined_sdxl = { 'Canny XL': ('TencentARC/t2i-adapter-canny-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }), 'LineArt XL': ('TencentARC/t2i-adapter-lineart-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }), 'Sketch XL': ('TencentARC/t2i-adapter-sketch-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }), 'Zoe Depth XL': ('TencentARC/t2i-adapter-depth-zoe-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }), 'OpenPose XL': ('TencentARC/t2i-adapter-openpose-sdxl-1.0', { 'use_safetensors': True }), 'Midas Depth XL': ('TencentARC/t2i-adapter-depth-midas-sdxl-1.0', { 'use_safetensors': True, 'variant': 'fp16' }), } models = {} all_models = {} 
def load(self, model_id: str = None, force: bool = True) -> str:
    """Load a T2I-Adapter model by its display name.

    Args:
        model_id: key into `all_models`; falls back to `self.model_id`.
            `None` or the string 'None' unloads the current model.
        force: reload even when `model_id` matches the already loaded id.

    Returns:
        A status string after a completed load or after an exception;
        `None` on early exits (unload, unknown id, already loaded).
    """
    with load_lock:
        try:
            t0 = time.time()
            model_id = model_id or self.model_id
            if model_id is None or model_id == 'None':
                self.reset()
                return
            if model_id not in all_models:
                log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}')
                return
            model_path, model_args = all_models[model_id]
            # per-load copy: merging model_args into self.load_config made one model's
            # arguments (e.g. variant) and local_files_only persist into later loads
            load_config = {**self.load_config, **model_args}
            from modules.shared import opts
            if opts.offline_mode:
                load_config["local_files_only"] = True
                os.environ['HF_HUB_OFFLINE'] = '1'
            else:
                os.environ.pop('HF_HUB_OFFLINE', None)
                os.unsetenv('HF_HUB_OFFLINE')
            if model_path is None:
                log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                return
            if model_id == self.model_id and not force:
                # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded')
                return
            log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}"')
            if model_path.endswith(('.pth', '.pt', '.safetensors', '.bin')):
                from huggingface_hub import hf_hub_download
                parts = model_path.split('/')
                repo_id = f'{parts[0]}/{parts[1]}'
                filename = '/'.join(parts[2:])
                # forward only download-related options: model-loading options such as
                # use_safetensors/variant are not hf_hub_download parameters
                download_keys = ('cache_dir', 'local_files_only', 'revision', 'token', 'force_download', 'proxies')
                download_config = {k: v for k, v in load_config.items() if k in download_keys}
                model = hf_hub_download(repo_id, filename, **download_config)
                self.model = T2IAdapter.from_pretrained(model, **load_config)
            else:
                self.model = T2IAdapter.from_pretrained(model_path, **load_config)
            if self.device is not None:
                self.model.to(self.device)
            if self.dtype is not None:
                self.model.to(self.dtype)
            t1 = time.time()
            self.model_id = model_id
            log.debug(f'Control {what} loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
            return f'{what} loaded model: {model_id}'
        except Exception as e:
            log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
            errors.display(e, f'Control {what} load')
            return f'{what} failed to load model: {model_id}'
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["sd-t2iadapter"] = StableDiffusionAdapterPipeline pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["sd-t2iadapter"] = StableDiffusionAdapterPipeline pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["sdxl-t2iadapter"] = StableDiffusionXLAdapterPipeline pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["sdxl-t2iadapter"] = StableDiffusionXLAdapterPipeline pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["sdxl-t2iadapter"] = StableDiffusionXLAdapterPipeline """ if pipeline.__class__.__name__ == 'StableDiffusionAdapterPipeline' or pipeline.__class__.__name__ == 'StableDiffusionXLAdapterPipeline': pass # already initialized if detect.is_sdxl(pipeline): self.pipeline = StableDiffusionXLAdapterPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), adapter=adapter, ) sd_models.move_model(self.pipeline, pipeline.device) sd_models.apply_balanced_offload(self.pipeline, force=True) elif detect.is_sd15(pipeline): self.pipeline = StableDiffusionAdapterPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), requires_safety_checker=False, safety_checker=None, adapter=adapter, ) sd_models.move_model(self.pipeline, pipeline.device) sd_models.apply_balanced_offload(self.pipeline, force=True) else: log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type') return if dtype is not None and self.pipeline is not None: self.pipeline.dtype = dtype t1 = time.time() if self.pipeline is not None: log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') 
else: log.error(f'Control {what} pipeline: not initialized') def restore(self): self.pipeline = None return self.orig_pipeline ================================================ FILE: modules/control/units/xs.py ================================================ import os import time from typing import Union import threading from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from modules.shared import log, opts, listdir from modules import errors, sd_models from modules.control.units.xs_model import ControlNetXSModel from modules.control.units.xs_pipe import StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline from modules.control.units import detect what = 'ControlNet-XS' debug = log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None debug('Trace: CONTROL') predefined_sd15 = { } predefined_sdxl = { 'Canny': 'UmerHA/ConrolNetXS-SDXL-canny', 'Depth': 'UmerHA/ConrolNetXS-SDXL-depth', } models = {} all_models = {} all_models.update(predefined_sd15) all_models.update(predefined_sdxl) cache_dir = 'models/control/xs' load_lock = threading.Lock() def find_models(): path = os.path.join(opts.control_dir, 'xs') files = listdir(path) files = [f for f in files if f.endswith('.safetensors')] downloaded_models = {} for f in files: basename = os.path.splitext(os.path.relpath(f, path))[0] downloaded_models[basename] = os.path.join(path, f) all_models.update(downloaded_models) return downloaded_models def list_models(refresh=False): global models # pylint: disable=global-statement import modules.shared if not refresh and len(models) > 0: return models models = {} if modules.shared.sd_model_type == 'none': models = ['None'] elif modules.shared.sd_model_type == 'sdxl': models = ['None'] + sorted(predefined_sdxl) + sorted(find_models()) elif modules.shared.sd_model_type == 'sd': models = ['None'] + sorted(predefined_sd15) + sorted(find_models()) else: log.error(f'Control {what} model list failed: 
unknown model type') models = ['None'] + sorted(predefined_sd15) + sorted(predefined_sdxl) + sorted(find_models()) debug(f'Control list {what}: path={cache_dir} models={models}') return models class ControlNetXS(): def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None): self.model: ControlNetXSModel = None self.model_id: str = model_id self.device = device self.dtype = dtype self.load_config = { 'cache_dir': cache_dir, 'learn_embedding': True } if load_config is not None: self.load_config.update(load_config) if model_id is not None: self.load() def __str__(self): return f' ControlNetXS(id={self.model_id} model={self.model.__class__.__name__})' if self.model_id and self.model else '' def reset(self): if self.model is not None: debug(f'Control {what} model unloaded') self.model = None self.model_id = None def load(self, model_id: str = None, time_embedding_mix: float = 0.0, force: bool = True) -> str: with load_lock: try: t0 = time.time() model_id = model_id or self.model_id if model_id is None or model_id == 'None': self.reset() return if model_id not in all_models: log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}') return model_path = all_models[model_id] if model_path == '': return if model_path is None: log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id') return if model_id == self.model_id and not force: # log.debug(f'Control {what} model: id="{model_id}" path="{model_path}" already loaded') return self.load_config['time_embedding_mix'] = time_embedding_mix if opts.offline_mode: self.load_config["local_files_only"] = True os.environ['HF_HUB_OFFLINE'] = '1' else: os.environ.pop('HF_HUB_OFFLINE', None) os.unsetenv('HF_HUB_OFFLINE') log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}') if model_path.endswith('.safetensors'): self.model = ControlNetXSModel.from_single_file(model_path, **self.load_config) else: 
self.model = ControlNetXSModel.from_pretrained(model_path, **self.load_config) if self.device is not None: self.model.to(self.device) if self.dtype is not None: self.model.to(self.dtype) t1 = time.time() self.model_id = model_id log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}') return f'{what} loaded model: {model_id}' except Exception as e: log.error(f'Control {what} model load failed: id="{model_id}" error={e}') errors.display(e, f'Control {what} load') return f'{what} failed to load model: {model_id}' class ControlNetXSPipeline(): def __init__(self, controlnet: Union[ControlNetXSModel, list[ControlNetXSModel]], pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None): t0 = time.time() self.orig_pipeline = pipeline self.pipeline = None if pipeline is None: log.error(f'Control {what} pipeline: model not loaded') return if detect.is_sdxl(pipeline): self.pipeline = StableDiffusionXLControlNetXSPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, text_encoder_2=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer, tokenizer_2=pipeline.tokenizer_2, unet=pipeline.unet, scheduler=pipeline.scheduler, # feature_extractor=getattr(pipeline, 'feature_extractor', None), controlnet=controlnet, # can be a list ) sd_models.move_model(self.pipeline, pipeline.device) sd_models.apply_balanced_offload(self.pipeline, force=True) elif detect.is_sd15(pipeline): self.pipeline = StableDiffusionControlNetXSPipeline( vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, unet=pipeline.unet, scheduler=pipeline.scheduler, feature_extractor=getattr(pipeline, 'feature_extractor', None), requires_safety_checker=False, safety_checker=None, controlnet=controlnet, # can be a list ) sd_models.move_model(self.pipeline, pipeline.device) sd_models.apply_balanced_offload(self.pipeline, force=True) else: log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported 
model type') return if dtype is not None and self.pipeline is not None: self.pipeline = self.pipeline.to(dtype) t1 = time.time() if self.pipeline is not None: log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') else: log.error(f'Control {what} pipeline: not initialized') def restore(self): self.pipeline = None return self.orig_pipeline ================================================ FILE: modules/control/units/xs_model.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import functional as F from torch.nn.modules.normalization import GroupNorm from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.lora import LoRACompatibleConv from diffusers.models.modeling_utils import ModelMixin try: from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, Downsample2D, ResnetBlock2D, Transformer2DModel, UpBlock2D, Upsample2D # pylint: disable=no-name-in-module except Exception: pass try: from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, Downsample2D, ResnetBlock2D, Transformer2DModel, UpBlock2D, Upsample2D except Exception: pass try: from diffusers.models.unet_2d_condition import UNet2DConditionModel except Exception: from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.utils import BaseOutput, logging, USE_PEFT_BACKEND logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetXSOutput(BaseOutput): """ The output of [`ControlNetXSModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model output, but is already the final output. 
""" sample: torch.FloatTensor = None # copied from diffusers.models.controlnet.ControlNetConditioningEmbedding class ControlNetConditioningEmbedding(nn.Module): """ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full model) to encode image-space conditions ... into feature maps ..." """ def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class ControlNetXSModel(ModelMixin, ConfigMixin): r""" A ControlNet-XS model This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). 
Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation of [`UNet2DConditionModel`] for them. Parameters: conditioning_channels (`int`, defaults to 3): Number of channels of conditioning input (e.g. an image) controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `controlnet_cond_embedding` layer. time_embedding_input_dim (`int`, defaults to 320): Dimension of input into time embedding. Needs to be same as in the base model. time_embedding_dim (`int`, defaults to 1280): Dimension of output from time embedding. Needs to be same as in the base model. learn_embedding (`bool`, defaults to `False`): Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. time_embedding_mix (`float`, defaults to 1.0): Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used. base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`): Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it. """ @classmethod def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True): """ Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS). Parameters: base_model (`UNet2DConditionModel`): Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL. is_sdxl (`bool`, defaults to `True`): Whether passed `base_model` is a StableDiffusion-XL model. 
""" def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int): """ Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why). The original ControlNet-XS model, however, define the number of attention heads. That's why compute the dimensions needed to get the correct number of attention heads. """ block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels] dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels] return dim_attn_heads if is_sdxl: return ControlNetXSModel.from_unet( base_model, time_embedding_mix=0.95, learn_embedding=True, size_ratio=0.1, conditioning_embedding_out_channels=(16, 32, 96, 256), num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64), ) else: return ControlNetXSModel.from_unet( base_model, time_embedding_mix=1.0, learn_embedding=True, size_ratio=0.0125, conditioning_embedding_out_channels=(16, 32, 96, 256), num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8), ) @classmethod def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str): """To create correctly sized connections between base and control model, we need to know the input and output channels of each subblock. Parameters: unet (`UNet2DConditionModel`): Unet of which the subblock channels sizes are to be gathered. base_or_control (`str`): Needs to be either "base" or "control". If "base", decoder is also considered. 
""" if base_or_control not in ["base", "control"]: raise ValueError("`base_or_control` needs to be either `base` or `control`") channel_sizes = {"down": [], "mid": [], "up": []} # input convolution channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels)) # encoder blocks for module in unet.down_blocks: if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): for r in module.resnets: channel_sizes["down"].append((r.in_channels, r.out_channels)) if module.downsamplers: channel_sizes["down"].append( (module.downsamplers[0].channels, module.downsamplers[0].out_channels) ) else: raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.") # middle block channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels)) # decoder blocks if base_or_control == "base": for module in unet.up_blocks: if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)): for r in module.resnets: channel_sizes["up"].append((r.in_channels, r.out_channels)) else: raise ValueError( f"Encountered unknown module of type {type(module)} while creating ControlNet-XS." 
) return channel_sizes @register_to_config def __init__( self, conditioning_channels: int = 3, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), controlnet_conditioning_channel_order: str = "rgb", # pylint: disable=unused-argument time_embedding_input_dim: int = 320, time_embedding_dim: int = 1280, time_embedding_mix: float = 1.0, # pylint: disable=unused-argument learn_embedding: bool = False, # pylint: disable=unused-argument base_model_channel_sizes: Dict[str, List[Tuple[int]]] = { "down": [ (4, 320), (320, 320), (320, 320), (320, 320), (320, 640), (640, 640), (640, 640), (640, 1280), (1280, 1280), ], "mid": [(1280, 1280)], "up": [ (2560, 1280), (2560, 1280), (1920, 1280), (1920, 640), (1280, 640), (960, 640), (960, 320), (640, 320), (640, 320), ], }, sample_size: Optional[int] = None, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), block_out_channels: Tuple[int] = (320, 640, 1280, 1280), norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, upcast_attention: bool = False, ): super().__init__() # 1 - Create control unet self.control_model = UNet2DConditionModel( sample_size=sample_size, down_block_types=down_block_types, up_block_types=up_block_types, block_out_channels=block_out_channels, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, transformer_layers_per_block=transformer_layers_per_block, attention_head_dim=num_attention_heads, use_linear_projection=True, upcast_attention=upcast_attention, time_embedding_dim=time_embedding_dim, ) # 2 - Do model surgery on control model # 2.1 - Allow to use the same time information as the base model 
adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim) # 2.2 - Allow for information infusion from base model # We concat the output of each base encoder subblocks to the input of the next control encoder subblock # (We ignore the 1st element, as it represents the `conv_in`.) extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]] it_extra_input_channels = iter(extra_input_channels) for b, block in enumerate(self.control_model.down_blocks): for r in range(len(block.resnets)): increase_block_input_in_encoder_resnet( self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels) ) if block.downsamplers: increase_block_input_in_encoder_downsampler( self.control_model, block_no=b, by=next(it_extra_input_channels) ) increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1]) # 2.3 - Make group norms work with modified channel sizes adjust_group_norms(self.control_model) # 3 - Gather Channel Sizes self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control") self.ch_inout_base = base_model_channel_sizes # 4 - Build connections between base and control model self.down_zero_convs_out = nn.ModuleList([]) self.down_zero_convs_in = nn.ModuleList([]) self.middle_block_out = nn.ModuleList([]) self.middle_block_in = nn.ModuleList([]) self.up_zero_convs_out = nn.ModuleList([]) self.up_zero_convs_in = nn.ModuleList([]) for ch_io_base in self.ch_inout_base["down"]: self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1])) for i in range(len(self.ch_inout_ctrl["down"])): self.down_zero_convs_out.append( self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1]) ) self.middle_block_out = self._make_zero_conv( self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1] ) self.up_zero_convs_out.append( 
self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1]) ) for i in range(1, len(self.ch_inout_ctrl["down"])): self.up_zero_convs_out.append( self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1]) ) # 5 - Create conditioning hint embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) # In the mininal implementation setting, we only need the control model up to the mid block del self.control_model.up_blocks del self.control_model.conv_norm_out del self.control_model.conv_out @classmethod def from_unet( cls, unet: UNet2DConditionModel, conditioning_channels: int = 3, conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), controlnet_conditioning_channel_order: str = "rgb", learn_embedding: bool = False, time_embedding_mix: float = 1.0, block_out_channels: Optional[Tuple[int]] = None, size_ratio: Optional[float] = None, num_attention_heads: Optional[Union[int, Tuple[int]]] = 8, norm_num_groups: Optional[int] = None, ): r""" Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it. conditioning_channels (`int`, defaults to 3): Number of channels of conditioning input (e.g. an image) conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `controlnet_cond_embedding` layer. controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. learn_embedding (`bool`, defaults to `False`): Wether to use time embedding of the control model. 
If yes, the time embedding is a linear interpolation of the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`. time_embedding_mix (`float`, defaults to 1.0): Linear interpolation parameter used if `learn_embedding` is `True`. block_out_channels (`Tuple[int]`, *optional*): Down blocks output channels in control model. Either this or `size_ratio` must be given. size_ratio (float, *optional*): When given, block_out_channels is set to a relative fraction of the base model's block_out_channels. Either this or `block_out_channels` must be given. num_attention_heads (`Union[int, Tuple[int]]`, *optional*): The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. norm_num_groups (int, *optional*, defaults to `None`): The number of groups to use for the normalization of the control unet. If `None`, `int(unet.config.norm_num_groups * size_ratio)` is taken. 
""" # Check input fixed_size = block_out_channels is not None relative_size = size_ratio is not None if not fixed_size ^ relative_size: raise ValueError("Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing).") # Create model if block_out_channels is None: block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels] # Check that attention heads and group norms match channel sizes # - attention heads def attn_heads_match_channel_sizes(attn_heads, channel_sizes): if isinstance(attn_heads, (tuple, list)): return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes)) else: return all(c % attn_heads == 0 for c in channel_sizes) num_attention_heads = num_attention_heads or unet.config.attention_head_dim if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels): raise ValueError( f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually." 
) # - group norms def group_norms_match_channel_sizes(num_groups, channel_sizes): return all(c % num_groups == 0 for c in channel_sizes) if norm_num_groups is None: if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels): norm_num_groups = unet.config.norm_num_groups else: norm_num_groups = min(block_out_channels) if not group_norms_match_channel_sizes(norm_num_groups, block_out_channels): raise ValueError(f'ControlNetXSModel mismatch: block_out_channels={block_out_channels} norm_num_groups={unet.config.norm_num_groups}') def get_time_emb_input_dim(unet: UNet2DConditionModel): return unet.time_embedding.linear_1.in_features def get_time_emb_dim(unet: UNet2DConditionModel): return unet.time_embedding.linear_2.out_features # Clone params from base unet if # (i) it's required to build SD or SDXL, and # (ii) it's not used for the time embedding (as time embedding of control model is never used), and # (iii) it's not set further below anyway to_keep = [ "cross_attention_dim", "down_block_types", "sample_size", "transformer_layers_per_block", "up_block_types", "upcast_attention", ] kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep} kwargs.update(block_out_channels=block_out_channels) kwargs.update(num_attention_heads=num_attention_heads) kwargs.update(norm_num_groups=norm_num_groups) # Add controlnetxs-specific params kwargs.update( conditioning_channels=conditioning_channels, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, time_embedding_input_dim=get_time_emb_input_dim(unet), time_embedding_dim=get_time_emb_dim(unet), time_embedding_mix=time_embedding_mix, learn_embedding=learn_embedding, base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"), conditioning_embedding_out_channels=conditioning_embedding_out_channels, ) return cls(**kwargs) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A 
dictionary containing all attention processors used in the model with indexed by its weight name. """ return self.control_model.attn_processors def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False ): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ self.control_model.set_attn_processor(processor, _remove_lora) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.control_model.set_default_attn_processor() def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" self.control_model.set_attention_slice(slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (UNet2DConditionModel)): if value: module.enable_gradient_checkpointing() else: module.disable_gradient_checkpointing() def forward( self, base_model: UNet2DConditionModel, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, return_dict: bool = True, ) -> Union[ControlNetXSOutput, Tuple]: """ The [`ControlNetModel`] forward method. Args: base_model (`UNet2DConditionModel`): The base unet model we want to control. sample (`torch.FloatTensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. controlnet_cond (`torch.FloatTensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): How much the control model affects the base model outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. 
If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order channel_order = self.config.controlnet_conditioning_channel_order # pylint: disable=no-member if channel_order == "rgb": # in rgb order by default ... elif channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") # scale control strength n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out) scale_list = torch.full((n_connections,), conditioning_scale) # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. 
time timesteps = timestep if not torch.is_tensor(timesteps): # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = base_model.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) if self.config.learn_embedding: # pylint: disable=no-member base_model = base_model.to(self.control_model.device) ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond) base_temb = base_model.time_embedding(t_emb, timestep_cond) interpolation_param = self.config.time_embedding_mix**0.3 # pylint: disable=no-member temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: temb = base_model.time_embedding(t_emb) # added time & text embeddings aug_emb = None if base_model.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if base_model.config.class_embed_type == "timestep": class_labels = base_model.time_proj(class_labels) class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype) temb = temb + class_emb if base_model.config.addition_embed_type is not None: if base_model.config.addition_embed_type == "text": aug_emb = base_model.add_embedding(encoder_hidden_states) elif base_model.config.addition_embed_type == "text_image": raise NotImplementedError() elif 
base_model.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = base_model.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) aug_emb = base_model.add_embedding(add_embeds) elif base_model.config.addition_embed_type == "image": raise NotImplementedError() elif base_model.config.addition_embed_type == "image_hint": raise NotImplementedError() temb = temb + aug_emb if aug_emb is not None else temb # text embeddings cemb = encoder_hidden_states # Preparation guided_hint = self.controlnet_cond_embedding(controlnet_cond) h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] it_down_convs_in, it_down_convs_out, _it_dec_convs_in, it_up_convs_out = map( iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out) ) scales = iter(scale_list) base_down_subblocks = to_sub_blocks(base_model.down_blocks) ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks) base_mid_subblocks = to_sub_blocks([base_model.mid_block]) ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block]) base_up_subblocks = to_sub_blocks(base_model.up_blocks) # Cross Control # 0 - conv in h_base = base_model.conv_in(h_base) h_ctrl = self.control_model.conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint h_base = h_base + 
def _make_zero_conv(self, in_channels, out_channels=None):
    """
    Build a 1x1 convolution whose parameters are all zero.

    Args:
        in_channels (`int`): Number of input channels of the convolution.
        out_channels (`int`, *optional*): Number of output channels. When falsy (`None` or `0`) the
            convolution keeps the input width, i.e. `in_channels` is used.

    Returns:
        The `nn.Conv2d` (kernel size 1, no padding) after it has been passed through `zero_module`.
    """
    # keep running track of channels sizes
    self.in_channels = in_channels
    self.out_channels = out_channels or in_channels

    # Build the layer from the resolved width: the raw argument may be None,
    # which nn.Conv2d cannot accept as a channel count.
    return zero_module(nn.Conv2d(in_channels, self.out_channels, 1, padding=0))
class SubBlock(nn.ModuleList):
    """Largest run of layers that the base or the control model executes on its own.

    Information is concatenated from base to control before every subblock, and added from control to base
    after every subblock; the layers inside one subblock run back to back without such an exchange.
    """

    def __init__(self, ms, *args, **kwargs):
        # A single module is wrapped so that nn.ModuleList always receives an iterable.
        modules = ms if is_iterable(ms) else [ms]
        super().__init__(modules, *args, **kwargs)

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor,
        cemb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Run every child in order, handing each layer type the inputs it expects.

        Resnet blocks receive the time embedding, transformers receive the text embedding together with the
        attention mask and cross-attention kwargs, and down-/upsamplers receive the hidden state only.

        Raises:
            ValueError: If a child is none of the supported layer types.
        """
        for m in self:
            if isinstance(m, ResnetBlock2D):
                x = m(x, temb)
                continue
            if isinstance(m, Transformer2DModel):
                transformer_out = m(
                    x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs
                )
                x = transformer_out.sample
                continue
            if isinstance(m, (Downsample2D, Upsample2D)):
                x = m(x)
                continue
            raise ValueError(
                f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`,  `Downsample2D` or `Upsample2D`"
            )

        return x
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
    """Increase channels sizes to allow for additional concatted information from base model

    Replaces the convolution of the first downsampler in `unet.down_blocks[block_no]` with a newly
    constructed one that has `by` more input channels; every other constructor argument is read from the
    old convolution. The downsampler's `channels` attribute is increased by the same amount.

    Args:
        unet: Model whose encoder downsampler is modified in place.
        block_no: Index into `unet.down_blocks`.
        by: Number of input channels to add.
    """
    downsampler = unet.down_blocks[block_no].downsamplers[0]
    old_down = downsampler.conv

    args = [
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "bias",
        "padding_mode",
    ]
    if not USE_PEFT_BACKEND:
        args.append("lora_layer")

    for a in args:
        assert hasattr(old_down, a)
    kwargs = {a: getattr(old_down, a) for a in args}
    # As a constructor parameter `bias` is a boolean, but as an attribute it is a tensor (or None when the
    # layer has no bias). Testing for the key, as before, was always true and forced a bias onto the new conv.
    kwargs["bias"] = old_down.bias is not None
    kwargs["in_channels"] += by  # surgery done here

    # swap old with new modules
    downsampler.conv = nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
    downsampler.channels += by  # surgery done here
def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
    """Raise the group count of norm layers that currently use fewer than `max_num_group` groups.

    Visits `norm1`/`norm2` of every resnet and, where a block has `attentions`, the `norm` of every
    attention, across `unet.down_blocks` and `unet.mid_block`. A layer with fewer than `max_num_group`
    groups gets `num_groups` set to the largest divisor of its `num_channels` that does not exceed
    `max_num_group`, or to `num_channels` itself when `max_num_group` is at least that large.
    Layers are modified in place; nothing is returned.
    """

    def find_denominator(number, start):
        # Walk downwards from `start` until a value divides `number` evenly.
        if start >= number:
            return number
        candidate = start
        while candidate != 0:
            if number % candidate == 0:
                return candidate
            candidate -= 1
        return None

    def regroup(norm):
        if norm.num_groups < max_num_group:
            norm.num_groups = find_denominator(norm.num_channels, start=max_num_group)

    for block in [*unet.down_blocks, unet.mid_block]:
        # resnets
        for resnet in block.resnets:
            regroup(resnet.norm1)
            regroup(resnet.norm2)

        # transformers
        if hasattr(block, "attentions"):
            for attention in block.attentions:
                regroup(attention.norm)
def zero_module(module):
    """Set every parameter of `module` to zero in place and hand the same module back."""
    parameters = module.parameters()
    for tensor in parameters:
        nn.init.zeros_(tensor)
    return module
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np from PIL import Image import torch import torch.nn.functional as F from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPImageProcessor from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.import_utils import is_invisible_watermark_available from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from modules.control.units.xs_model import ControlNetXSModel if is_invisible_watermark_available(): from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableDiffusionXLControlNetXSPipeline( DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): Second frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. tokenizer_2 ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. controlnet ([`ControlNetXSModel`]): Provides additional conditioning to the `unet` during the denoising process. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): Whether the negative prompt embeddings should always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: ControlNetXSModel,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """Validate the ControlNet-XS model against the VAE and register all pipeline components.

    Raises:
        ValueError: If `controlnet` is a list that does not hold exactly one model, or if
            `controlnet._check_if_vae_compatible(vae)` reports mismatching downsampling factors.
    """
    super().__init__()

    # A list is accepted only as a wrapper around exactly one ControlNet-XS model.
    if isinstance(controlnet, list):
        if len(controlnet) != 1:
            raise ValueError(
                "ControlNetXS pipeline only supports a single ControlNetXS model"
            )
        controlnet = controlnet[0]

    compatibility = controlnet._check_if_vae_compatible(vae)
    vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = compatibility
    if not vae_compatible:
        raise ValueError(
            f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
    )

    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control image goes through the same processor settings except that it is not normalized.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )

    # When the caller does not decide, watermark exactly when the package is available.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
""" self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, prompt_2: Optional[str] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if do_classifier_free_guidance: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, prompt_2, image, callback_steps, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 
) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetXSModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetXSModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetXSModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: assert False start, end = control_guidance_start, control_guidance_end if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not image_is_pil_list and not image_is_tensor_list and not image_is_np_list ): raise TypeError( f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" ) if image_is_pil: image_batch_size = 1 else: 
image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stages where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
""" if not hasattr(self, "unet"): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, ): r""" The call function to the pipeline for generation. 
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled text embeddings are generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. control_guidance_start (`float`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. 
control_guidance_end (`float`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, image, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt, prompt_2, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # 4. Prepare image if isinstance(controlnet, ControlNetXSModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Prepare extra step kwargs. 
Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Prepare added time ids & embeddings if isinstance(image, list): original_size = original_size or image[0].shape[-2:] else: original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # predict the noise residual dont_control = ( i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end ) if dont_control: noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=True, ).sample else: noise_pred = self.controlnet( base_model=self.unet, sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=True, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > 
num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # manually for max memory savings if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if output_type != "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if output_type != "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) class StableDiffusionControlNetXSPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. controlnet ([`ControlNetXSModel`]): Provides additional conditioning to the `unet` during the denoising process. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->unet->vae>controlnet" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet: ControlNetXSModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible( vae ) if not vae_compatible: raise ValueError( f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        # Thin wrapper around `encode_prompt` that returns a single tensor instead of a tuple.
        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp: negative embeddings (index 1) first, then the prompt embeddings (index 0)
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            A tuple `(prompt_embeds, negative_prompt_embeds)`.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # batch size comes from the prompt when given, otherwise from the precomputed embeddings
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # tokenize again without truncation, only to detect and report what was cut off
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # pick the dtype to cast the embeddings to: text encoder first, then unet, then leave as is
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # pad the negative prompt to the same sequence length as the positive embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        # Without a safety checker the image is returned untouched and the nsfw flags are None.
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # the feature extractor is fed PIL images, whatever form `image` arrives in
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image,
clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetXSModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetXSModel) ): self.check_image(image, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetXSModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetXSModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") else: assert False start, end = control_guidance_start, control_guidance_end if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not image_is_pil_list and not image_is_tensor_list and not image_is_np_list ): raise TypeError( f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" ) if image_is_pil: image_batch_size = 1 else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is 
not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stages where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
""" if not hasattr(self, "unet"): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. 
The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. 
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image if isinstance(controlnet, ControlNetXSModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Prepare extra step kwargs. Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual dont_control = ( i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end ) if dont_control: noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=True, ).sample else: noise_pred = self.controlnet( base_model=self.unet, sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, return_dict=True, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, 
latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if output_type != "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: modules/control/util.py ================================================ import os import sys import random import cv2 import numpy as np import torch annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') def dict2str(d: dict): arr = [f'{name} {d[name]}' for i, name in enumerate(d) if d[name] is not None and d[name] != ''] return ' | '.join(arr) def HWC3(x): assert x.dtype == np.uint8 if x.ndim == 2: x = x[:, :, None] assert x.ndim == 3 _H, _W, C = x.shape assert C == 1 or C == 3 or C == 4 if C == 3: return x if C == 1: return np.concatenate([x, x, x], axis=2) if C == 4: color = x[:, :, 0:3].astype(np.float32) alpha = x[:, :, 3:4].astype(np.float32) / 255.0 y = color * alpha + 255.0 * (1.0 - alpha) y = y.clip(0, 255).astype(np.uint8) return y return x # should not happen def make_noise_disk(H, W, C, F): noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_LANCZOS4) noise = 
noise[F: F + H, F: F + W] noise -= np.min(noise) noise /= np.max(noise) if C == 1: noise = noise[:, :, None] return noise def nms(x, t, s): x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) y = np.zeros_like(x) for f in [f1, f2, f3, f4]: np.putmask(y, cv2.dilate(x, kernel=f) == x, x) z = np.zeros_like(y, dtype=np.uint8) z[y > t] = 255 # pylint: disable=unsupported-assignment-operation return z def min_max_norm(x): x -= np.min(x) x /= np.maximum(np.max(x), 1e-5) return x def safe_step(x, step=2): y = x.astype(np.float32) * float(step + 1) y = y.astype(np.int32).astype(np.float32) / float(step) return y def img2mask(img, H, W, low=10, high=90): assert img.ndim == 3 or img.ndim == 2 assert img.dtype == np.uint8 if img.ndim == 3: y = img[:, :, random.randrange(0, img.shape[2])] else: y = img y = cv2.resize(y, (W, H), interpolation=cv2.INTER_LANCZOS4) if random.uniform(0, 1) < 0.5: y = 255 - y return y < np.percentile(y, random.randrange(low, high)) def resize_image(input_image, resolution): H, W, _C = input_image.shape H = float(H) W = float(W) k = float(resolution) / min(H, W) H *= k W *= k H = int(np.round(H / 64.0)) * 64 W = int(np.round(W / 64.0)) * 64 img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4) return img def torch_gc(): if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() def ade_palette(): """ADE20K palette that maps each class to RGB values.""" return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]] def blend(images): if images is None or len(images) == 0: return images y = 
class Detailer: # abstract class used for postprocessing
    """Base class for detailers that post-process a generated image.

    Subclasses override `name` and `restore`. The class does not use an ABC
    metaclass, so `@abstractmethod` below is only a marker and the base class
    can still be instantiated; its `restore` is a pass-through.
    """

    def name(self):
        """Return the identifier compared against the configured detailer model."""
        return "None"

    @abstractmethod
    def restore(self, np_image, p=None):
        """Return the post-processed image; the base implementation returns it unchanged.

        `p` is optional and unused here. It is accepted because `detail()`
        in this module calls `restore(np_image, p)`, which would otherwise
        raise a TypeError for a detailer that does not override this method.
        """
        return np_image
def get_gpu_info():
    """Return a dictionary describing the active compute device.

    The keys depend on the module-level `backend`:
    - CUDA not available: `openvino` and `directml` report `device` plus the
      package version; any other backend, or any failure, yields `{}`.
    - CUDA available: `ipex`, `cuda`/`zluda` and `rocm` report `device` plus
      backend-specific version fields; other backends yield
      `{'device': 'unknown'}`. On failure the result is `{'error': ex}` and
      the exception is displayed when device debugging is enabled.
    """
    def get_driver():
        # torch.xpu does not exist in every torch build; guard it the same way has_xpu() does
        # so that a missing attribute cannot turn the whole CUDA report into an error
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            try:
                return torch.xpu.get_device_properties(torch.xpu.current_device()).driver_version
            except Exception:
                return ''
        elif torch.cuda.is_available() and torch.version.cuda:
            try:
                import subprocess
                # fixed command string, no external input is interpolated
                result = subprocess.run('nvidia-smi --query-gpu=driver_version --format=csv,noheader', shell=True, check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                version = result.stdout.decode(encoding="utf8", errors="ignore").strip()
                return version
            except Exception:
                return ''
        else:
            return ''

    def get_package_version(pkg: str):
        # returns None when the package is not part of the working set
        import pkg_resources
        spec = pkg_resources.working_set.by_key.get(pkg, None) # more reliable than importlib
        version = pkg_resources.get_distribution(pkg).version if spec is not None else None
        return version

    if not torch.cuda.is_available():
        try:
            if backend == 'openvino':
                from modules.intel.openvino import get_openvino_device
                return {
                    'device': get_openvino_device(), # pylint: disable=used-before-assignment
                    'openvino': get_package_version("openvino"),
                }
            elif backend == 'directml':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()}',
                    'directml': get_package_version("torch-directml"),
                }
            else:
                return {}
        except Exception:
            return {}
    else:
        try:
            if backend == 'ipex':
                return {
                    'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} n={torch.xpu.device_count()}',
                    'ipex': get_package_version('intel-extension-for-pytorch'),
                    'driver': get_driver(),
                }
            elif backend == 'cuda' or backend == 'zluda':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()} arch={torch.cuda.get_arch_list()[-1]} capability={torch.cuda.get_device_capability(device)}',
                    'cuda': torch.version.cuda,
                    'cudnn': torch.backends.cudnn.version(),
                    'driver': get_driver(),
                }
            elif backend == 'rocm':
                return {
                    'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} n={torch.cuda.device_count()}',
                    'hip': torch.version.hip,
                }
            else:
                return {
                    'device': 'unknown'
                }
        except Exception as ex:
            if debug:
                display(ex, 'Device exception')
            return { 'error': ex }
def torch_gc(force:bool=False, fast:bool=False, reason:str=None):
    """Run Python and device garbage collection when forced or when memory pressure requires it.

    Collection runs when `force` is set, when the configured threshold is 0 or
    the GPU usage percentage reaches it, or when the out-of-memory counter
    reported by memstats has grown since the previous call.

    Args:
        force: request collection regardless of current usage.
        fast: skip the Python `gc.collect()` pass and the post-collection statistics.
        reason: label used in the debug log; filled in automatically when None.

    Returns:
        A `(gpu, ram)` tuple of used-memory values as reported by memstats
        (units are whatever memstats reports; on directml the gpu value is
        `torch.cuda.memory_allocated()` divided by 2**30). The values are
        re-read after collection unless nothing was collected or `fast` is set,
        in which case the values sampled on entry are returned.
    """
    def get_stats():
        # snapshot of current usage: absolute used values, usage percentages and the oom counter
        mem_dict = memstats.memory_stats()
        gpu_dict = mem_dict.get('gpu', {})
        ram_dict = mem_dict.get('ram', {})
        oom = gpu_dict.get('oom', 0)
        ram = ram_dict.get('used', 0)
        if backend == "directml":
            gpu = torch.cuda.memory_allocated() / (1 << 30)
        else:
            gpu = gpu_dict.get('used', 0)
        # percentages fall back to 0 when no meaningful total is reported
        used_gpu = round(100 * gpu / gpu_dict.get('total', 1)) if gpu_dict.get('total', 1) > 1 else 0
        used_ram = round(100 * ram / ram_dict.get('total', 1)) if ram_dict.get('total', 1) > 1 else 0
        return gpu, used_gpu, ram, used_ram, oom

    global previous_oom # pylint: disable=global-statement
    import gc
    from modules import timer, memstats
    from modules.shared import cmd_opts

    t0 = time.time()
    gpu, used_gpu, ram, _used_ram, oom = get_stats()
    # a threshold of 0 means "always collect"; used for lowvram unless running on zluda
    threshold = 0 if (cmd_opts.lowvram and not cmd_opts.use_zluda) else opts.torch_gc_threshold
    collected = 0
    if reason is None and force:
        reason='force'
    if threshold == 0 or used_gpu >= threshold:
        force = True
        if reason is None:
            reason = 'threshold'
    if oom > previous_oom:
        # a new out-of-memory event was recorded since the last call
        previous_oom = oom
        log.warning(f'Torch GPU out-of-memory error: {memstats.memory_stats()}')
        force = True
        if reason is None:
            reason = 'oom'
    if debug:
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        log.trace(f'GC: run={force} fast={fast} used={used_gpu} threshold={threshold} fn={fn}')
    if force:
        # actual gc
        collected = gc.collect() if not fast else 0 # python gc
        if cuda_ok:
            try:
                with torch.cuda.device(get_cuda_device_string()):
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache() # cuda gc
                    torch.cuda.ipc_collect()
            except Exception as e:
                log.error(f'GC: {e}')
    else:
        # nothing collected: report the values sampled on entry
        return gpu, ram
    t1 = time.time()
    timer.process.add('gc', t1 - t0)
    if fast:
        # fast mode skips the second statistics pass and the debug log
        return gpu, ram

    new_gpu, new_used_gpu, new_ram, new_used_ram, oom = get_stats()
    before = { 'gpu': gpu, 'ram': ram }
    after = { 'gpu': new_gpu, 'ram': new_ram, 'oom': oom }
    utilization = { 'gpu': new_used_gpu, 'ram': new_used_ram }
    results = { 'gpu': round(gpu - new_gpu, 2), 'py': collected }
    # caller chain (two frames up and one frame up) for the log line
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    log.debug(f'GC: current={after} prev={before} load={utilization} gc={results} fn={fn} why={reason} time={t1-t0:.2f}')
    return new_gpu, new_ram
def set_cuda_tunable():
    """Apply the tunable-ops options to `torch.cuda.tunable`.

    Does nothing when no CUDA-like device is available or when the
    `torch_tunable_ops` option is 'default'. Otherwise enables or disables
    tunable ops and tuning, sets the tuning limits, points the results file at
    `<tunable_dir>/tunable.csv` and, when tunable ops end up enabled, logs the
    resulting state together with the number of lines already in that file.
    Any failure is reported as a warning and never raised.
    """
    if not cuda_ok:
        return
    try:
        if opts.torch_tunable_ops != 'default':
            torch.cuda.tunable.enable(opts.torch_tunable_ops == 'true')
            torch.cuda.tunable.tuning_enable(opts.torch_tunable_ops == 'true')
            torch.cuda.tunable.set_max_tuning_duration(1000) # set to high value as actual is min(duration, iterations)
            torch.cuda.tunable.set_max_tuning_iterations(opts.torch_tunable_limit)
            fn = os.path.join(opts.tunable_dir, 'tunable.csv')
            lines = 0 # number of existing entries; stays 0 when the file is missing or unreadable
            try:
                if os.path.exists(fn):
                    with open(fn, 'r', encoding='utf8') as f:
                        lines = sum(1 for _line in f)
            except Exception as e:
                # counting entries is informational only, so never let it block configuration
                log.debug(f'Torch tunable: fn="{fn}" {e}')
            torch.cuda.tunable.set_filename(fn)
            if torch.cuda.tunable.is_enabled():
                log.debug(f'Torch tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}" entries={lines}')
    except Exception as e:
        log.warning(f'Torch tunable: {e}')
def test_bf16():
    """Determine once whether bfloat16 is usable on the active device and cache the answer in `bf16_ok`.

    Unless the user explicitly selected BF16, some platforms/backends are
    rejected up front. Otherwise a small bfloat16 interpolation is run on the
    device and checked for the expected dtype and for an all-NaN result.
    """
    global bf16_ok # pylint: disable=global-statement
    if bf16_ok is not None:
        return bf16_ok
    if opts.cuda_dtype != 'BF16': # don't override if the user sets it
        if sys.platform == "darwin" or backend in {'openvino', 'directml', 'cpu'}: # override
            bf16_ok = False
            return bf16_ok
        if backend in ('rocm', 'zluda'):
            if backend == 'rocm':
                agent = get_hip_agent()
            else:
                from modules.zluda_installer import default_agent as agent
            # all cards before RDNA 3 except for CDNA cards
            too_old = agent is not None and agent.gfx_version < 0x1100 and agent.arch != rocm.MicroArchitecture.CDNA
            if too_old:
                bf16_ok = False
                return bf16_ok
    try:
        sample = torch.randn(1, 4, 32, 32).to(device=device, dtype=torch.bfloat16)
        upscaled = torch.nn.functional.interpolate(sample, size=(64, 64), mode="nearest")
        if upscaled.dtype != torch.bfloat16:
            raise RuntimeError('Torch BF16 test: dtype mismatch')
        if torch.all(torch.isnan(upscaled)).item():
            raise RuntimeError('Torch BF16 test: NaN')
        bf16_ok = True
    except Exception as ex:
        log.warning(f'Torch BF16 test fail: {ex}')
        bf16_ok = False
    return bf16_ok
def override_ipex_math():
    """Enable TF32 math for the ipex backend; a no-op for every other backend.

    Failures are logged as a warning and never raised.
    """
    if backend != "ipex":
        return
    try:
        ipex_math_mode_available = hasattr(torch.xpu, "set_fp32_math_mode") # not available with pure torch+xpu, requires ipex
        if ipex_math_mode_available:
            torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
        torch.backends.mkldnn.allow_tf32 = True
    except Exception as e:
        log.warning(f'Torch ipex: {e}')
def set_dtype():
    """Resolve the module-level dtypes and inference context from the current options.

    Runs the FP16/BF16 capability tests, picks one dtype for model, VAE and
    UNet from `opts.cuda_dtype`, then applies the `no_half` / `no_half_vae`
    overrides, the upcast flag and the inference context selection.
    """
    global dtype, dtype_vae, dtype_unet, unet_needs_upcast, inference_context # pylint: disable=global-statement
    test_fp16()
    test_bf16()

    requested = opts.cuda_dtype
    selected = None # stays None for an unrecognized setting, leaving current dtypes untouched
    if requested == 'Auto': # detect
        if bf16_ok:
            selected = torch.bfloat16
        elif fp16_ok:
            selected = torch.float16
        else:
            selected = torch.float32
    elif requested == 'FP32':
        selected = torch.float32
    elif requested == 'BF16':
        selected = torch.bfloat16
        if not bf16_ok:
            log.warning(f'Torch device capability failed: device={device} dtype={selected}')
    elif requested == 'FP16':
        selected = torch.float16
        if not fp16_ok:
            log.warning(f'Torch device capability failed: device={device} dtype={selected}')
    if selected is not None:
        dtype = dtype_vae = dtype_unet = selected

    if opts.no_half:
        dtype = dtype_vae = dtype_unet = torch.float32
        log.info(f'Torch override: no-half dtype={dtype}')
    if opts.no_half_vae:
        dtype_vae = torch.float32
        log.info(f'Torch override: no-half-vae dtype={dtype_vae}')
    unet_needs_upcast = opts.upcast_sampling

    mode = opts.inference_mode
    if mode == 'inference-mode':
        inference_context = torch.inference_mode
    elif mode == 'none':
        inference_context = contextlib.nullcontext
    else:
        inference_context = torch.no_grad
def without_autocast(disable=False):
    """Return a context manager that turns autocast off for the enclosed block.

    A null context is returned when `disable` is set or when torch autocast is
    not currently enabled, since there is nothing to switch off.
    """
    if disable:
        return contextlib.nullcontext()
    if backend == 'directml':
        if torch.is_autocast_enabled():
            return torch.dml.amp.autocast(enabled=False) # pylint: disable=unexpected-keyword-arg
        return contextlib.nullcontext()
    device_type = "cuda" if cuda_ok else "cpu"
    if not torch.is_autocast_enabled():
        return contextlib.nullcontext()
    return torch.autocast(device_type, enabled=False)
def normalize_device(dev):
    """Return `dev` as a torch.device, filling in index 0 when an indexed device type has none.

    Device types "cpu", "mps" and "meta" are returned as-is; for any other
    type a missing index is replaced by 0 so that e.g. "cuda" and "cuda:0"
    normalize to the same device.
    """
    resolved = torch.device(dev)
    if resolved.type in {"cpu", "mps", "meta"} or resolved.index is not None:
        return resolved
    return torch.device(str(dev), index=0)
danieldk (https://github.com/explosion/curated-transformers/pull/124) CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) if version.parse(torch.__version__) < version.parse("1.13"): # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), lambda _, self, *args, **kwargs: self.device.type != 'mps' and ((args and isinstance(args[0], torch.device) and args[0].type == 'mps') or (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))) # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) elif version.parse(torch.__version__) > version.parse("1.13.1"): cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) # pylint: disable=unnecessary-lambda-assignment CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), 
class Generator(torch.Generator):
    """torch.Generator replacement that is always created on the CPU."""

    def __init__(self, device: Optional[torch.device] = None):
        # the requested device is accepted for signature compatibility but ignored
        super().__init__("cpu")
def directml_init():
    """Install the DirectML backend as `torch.dml` and redirect selected `torch.cuda` entry points to it.

    Returns:
        `(True, None)` on success, or `(False, exception)` when the backend
        import or any of the assignments fails (the failure is also logged).
    """
    try:
        from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
        # Alternative of torch.cuda for DirectML.
        torch.dml = DirectML

        # report CUDA as unavailable while routing device queries to DirectML
        torch.cuda.is_available = lambda: False
        torch.cuda.device = torch.dml.device
        torch.cuda.device_count = torch.dml.device_count
        torch.cuda.current_device = torch.dml.current_device
        torch.cuda.get_device_name = torch.dml.get_device_name
        torch.cuda.get_device_properties = torch.dml.get_device_properties

        # cache management calls become no-ops
        torch.cuda.empty_cache = do_nothing
        torch.cuda.ipc_collect = do_nothing
        # memory statistics are served by the DirectML backend
        torch.cuda.memory_stats = torch.dml.memory_stats
        torch.cuda.mem_get_info = torch.dml.mem_get_info
        torch.cuda.memory_allocated = torch.dml.memory_allocated
        torch.cuda.max_memory_allocated = torch.dml.max_memory_allocated
        torch.cuda.reset_peak_memory_stats = torch.dml.reset_peak_memory_stats
        torch.cuda.utilization = lambda: 0

        # convenience method: move a tensor to the current DirectML device
        torch.Tensor.directml = lambda self: self.to(torch.dml.current_device())
    except Exception as e:
        log.error(f'DirectML initialization failed: {e}')
        return False, e
    return True, None
torch.dml.has_float64_support(device): torch.Tensor.__str__ = do_nothing_with_self CondFunc('torch.from_numpy', lambda orig_func, *args, **kwargs: orig_func(args[0].astype('float32')), lambda *args, **kwargs: args[1].dtype == float) _set_memory_provider() class OverrideItem(NamedTuple): value: str condition: Optional[Callable] message: Optional[str] opts_override_table = { "diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"), } def directml_override_opts(): from modules import shared if shared.cmd_opts.experimental: return count = 0 for key in opts_override_table: item = opts_override_table[key] if getattr(shared.opts, key) != item.value and (item.condition is None or item.condition(shared.opts)): count += 1 setattr(shared.opts, key, item.value) shared.log.warning(f'Overriding: {key}={item.value} {item.message if item.message is not None else ""}') if count > 0: shared.log.info(f'Options override: count={count}. If you want to keep them from overriding, run with --experimental argument.') _set_memory_provider() ================================================ FILE: modules/dml/amp/__init__.py ================================================ from .autocast_mode import autocast ================================================ FILE: modules/dml/amp/autocast_mode.py ================================================ import importlib from typing import Any, Optional import torch ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"] 
supported_cast_pairs = { torch.float16: (torch.float32,), torch.float32: (torch.float16,), } def forward(op, args: tuple, kwargs: dict): if not torch.dml.is_autocast_enabled: return op(*args, **kwargs) args = list(map(cast, args)) for kwarg in kwargs: kwargs[kwarg] = cast(kwargs[kwarg]) return op(*args, **kwargs) def cast(tensor: torch.Tensor): if not torch.is_tensor(tensor): return tensor dtype: torch.dtype = tensor.dtype if dtype not in supported_cast_pairs or (torch.dml.autocast_gpu_dtype != dtype and torch.dml.autocast_gpu_dtype not in supported_cast_pairs[dtype]): return tensor return tensor.type(torch.dml.autocast_gpu_dtype) def cond(op: str): if isinstance(op, str): func_path = op.split('.') for i in range(len(func_path)-1, -1, -1): try: resolved_obj = importlib.import_module('.'.join(func_path[:i])) break except ImportError: pass for attr_name in func_path[i:-1]: resolved_obj = getattr(resolved_obj, attr_name) op = getattr(resolved_obj, func_path[-1]) setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs)) for o in ops: cond(o) class autocast: prev: bool fast_dtype: torch.dtype = torch.float16 prev_fast_dtype: torch.dtype def __init__(self, dtype: Optional[torch.dtype] = torch.float16): self.fast_dtype = dtype def __enter__(self): self.prev = torch.dml.is_autocast_enabled self.prev_fast_dtype = torch.dml.autocast_gpu_dtype torch.dml.is_autocast_enabled = True torch.dml.autocast_gpu_dtype = self.fast_dtype def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): torch.dml.is_autocast_enabled = self.prev torch.dml.autocast_gpu_dtype = self.prev_fast_dtype ================================================ FILE: modules/dml/backend.py ================================================ # pylint: disable=no-member,no-self-argument,no-method-argument from typing import Optional, Callable import torch import torch_directml # pylint: disable=import-error import modules.dml.amp as amp from .utils import rDevice, get_device from 
.device import Device from .Generator import Generator from .device_properties import DeviceProperties def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: from .memory_amd import AMDMemoryProvider return AMDMemoryProvider.mem_get_info(get_device(device).index) def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: mem_info = DirectML.memory_provider.get_memory(get_device(device).index) return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"]) def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument return (8589934592, 8589934592) class DirectML: amp = amp device = Device Generator = Generator context_device: Optional[torch.device] = None is_autocast_enabled = False autocast_gpu_dtype = torch.float16 memory_provider = None def is_available() -> bool: return torch_directml.is_available() def is_directml_device(device: torch.device) -> bool: return device.type == "privateuseone" def has_float64_support(device: Optional[rDevice]=None) -> bool: return torch_directml.has_float64_support(get_device(device).index) def device_count() -> int: return torch_directml.device_count() def current_device() -> torch.device: return DirectML.context_device or DirectML.default_device() def default_device() -> torch.device: return torch_directml.device(torch_directml.default_device()) def get_device_string(device: Optional[rDevice]=None) -> str: return f"privateuseone:{get_device(device).index}" def get_device_name(device: Optional[rDevice]=None) -> str: return torch_directml.device_name(get_device(device).index) def get_device_properties(device: Optional[rDevice]=None) -> DeviceProperties: return DeviceProperties(get_device(device)) def memory_stats(device: Optional[rDevice]=None): return { "num_ooms": 0, "num_alloc_retries": 0, } mem_get_info: Callable = mem_get_info def memory_allocated(device: Optional[rDevice]=None) -> int: return 
sum(torch_directml.gpu_memory(get_device(device).index)) * (1 << 20) def max_memory_allocated(device: Optional[rDevice]=None): return DirectML.memory_allocated(device) # DirectML does not empty GPU memory def reset_peak_memory_stats(device: Optional[rDevice]=None): return ================================================ FILE: modules/dml/device.py ================================================ from typing import Optional import torch from .utils import rDevice, get_device class Device: idx: int def __enter__(self, device: Optional[rDevice]=None): torch.dml.context_device = get_device(device) self.idx = torch.dml.context_device.index def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init self.idx = get_device(device).index def __exit__(self, t, v, tb): torch.dml.context_device = None ================================================ FILE: modules/dml/device_properties.py ================================================ import torch class DeviceProperties: type: str = "directml" name: str major: int = 0 minor: int = 0 total_memory: int multi_processor_count: int = 1 def __init__(self, device: torch.device): self.name = torch.dml.get_device_name(device) self.total_memory = torch.dml.mem_get_info(device)[0] def __str__(self): return f"DeviceProperties(name='{self.name}', total_memory='{self.total_memory}')" def __repr__(self): return f"DeviceProperties(name='{self.name}', total_memory='{self.total_memory}')" ================================================ FILE: modules/dml/hijack/__init__.py ================================================ import modules.dml.hijack.torch import modules.dml.hijack.realesrgan_model import modules.dml.hijack.transformers import modules.dml.hijack.tomesd ================================================ FILE: modules/dml/hijack/realesrgan_model.py ================================================ import math import torch from modules.postprocess.realesrgan_model_arch import RealESRGANer from 
installer import log # DML Solution: Some of contents of output tensor turn to 0 after Extended Slices. Move it to cpu. def tile_process(self): batch, channel, height, width = self.img.shape output_height = height * self.scale output_width = width * self.scale output_shape = (batch, channel, output_height, output_width) # start with black image self.output = self.img.new_zeros(output_shape) tiles_x = math.ceil(width / self.tile_size) tiles_y = math.ceil(height / self.tile_size) # loop over all tiles for y in range(tiles_y): for x in range(tiles_x): # extract tile from input image ofs_x = x * self.tile_size ofs_y = y * self.tile_size # input tile area on total image input_start_x = ofs_x input_end_x = min(ofs_x + self.tile_size, width) input_start_y = ofs_y input_end_y = min(ofs_y + self.tile_size, height) # input tile area on total image with padding input_start_x_pad = max(input_start_x - self.tile_pad, 0) input_end_x_pad = min(input_end_x + self.tile_pad, width) input_start_y_pad = max(input_start_y - self.tile_pad, 0) input_end_y_pad = min(input_end_y + self.tile_pad, height) # input tile dimensions input_tile_width = input_end_x - input_start_x input_tile_height = input_end_y - input_start_y _tile_idx = y * tiles_x + x + 1 input_tile = self.img[0:self.img.shape[0], 0:self.img.shape[1], input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] # upscale tile try: with torch.no_grad(): output_tile = self.model(input_tile) except Exception as e: log.error(f'Upscale error: type=R-ESRGAN {e}') # output tile area on total image output_start_x = input_start_x * self.scale output_end_x = input_end_x * self.scale output_start_y = input_start_y * self.scale output_end_y = input_end_y * self.scale # output tile area without padding output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale output_end_x_tile = output_start_x_tile + input_tile_width * self.scale output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale output_end_y_tile 
= output_start_y_tile + input_tile_height * self.scale self.output = self.output.cpu() # put tile into output image self.output[0:self.output.shape[0], 0:self.output.shape[1], output_start_y:output_end_y, output_start_x:output_end_x] = output_tile.cpu()[0:output_tile.shape[0], 0:output_tile.shape[1], output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile] self.output = self.output.to(output_tile.device) RealESRGANer.tile_process = tile_process ================================================ FILE: modules/dml/hijack/tomesd.py ================================================ from typing import Type import torch from modules.dml.hijack.utils import catch_nan def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: class ToMeBlock(block_class): # Save for unpatching later _parent = block_class def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: m_a, m_c, m_m, u_a, u_c, u_m = tomesd.patch.compute_merge(x, self._tome_info) # This is where the meat of the computation happens x = u_a(self.attn1(m_a(self.norm1(x)), context=context if self.disable_self_attn else None)) + x x = catch_nan(lambda: (u_c(self.attn2(m_c(self.norm2(x)), context=context)) + x)) x = u_m(self.ff(m_m(self.norm3(x)))) + x return x return ToMeBlock try: import tomesd tomesd.patch.make_tome_block = make_tome_block except Exception: pass ================================================ FILE: modules/dml/hijack/torch.py ================================================ import torch from modules.sd_hijack_utils import CondFunc CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'privateuseone') # https://github.com/microsoft/DirectML/issues/400 CondFunc('torch.Tensor.new', lambda orig, self, *args, 
**kwargs: orig(self.cpu(), *args, **kwargs).to(self.device), lambda orig, self, *args, **kwargs: torch.dml.is_directml_device(self.device)) def cuda(self: torch.Tensor): return self.to(torch.dml.current_device()) torch.Tensor.cuda = cuda # https://github.com/lshqqytiger/stable-diffusion-webui-directml/issues/436 _pow_ = torch.Tensor.pow_ def pow_(self: torch.Tensor, *args, **kwargs): if self.dtype == torch.float64: return _pow_(self.cpu(), *args, **kwargs).to(self.device) return _pow_(self, *args, **kwargs) torch.Tensor.pow_ = pow_ _load = torch.load def load(f, map_location = "cpu", *args, **kwargs): if type(map_location) in (str, torch.device,): device = torch.device(map_location) if device.type == "privateuseone": data = _load(f, *args, map_location="cpu", **kwargs) for k in data: for weight in data[k]: data[k][weight] = data[k][weight].to(device) return data return _load(f, *args, map_location=map_location, **kwargs) torch.load = load ================================================ FILE: modules/dml/hijack/transformers.py ================================================ from typing import Optional import torch import transformers.models.clip.modeling_clip # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): """ Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape min = torch.tensor(torch.finfo(dtype).min, device="cpu") mask = torch.full((tgt_len, tgt_len), min, device=device) # https://discord.com/channels/1101998836328697867/1127441997184122920 mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) def CLIPTextEmbeddings_forward( self: transformers.models.clip.modeling_clip.CLIPTextEmbeddings, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: from modules.devices import dtype seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids).type(dtype) # Type correction. position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings transformers.models.clip.modeling_clip._make_causal_mask = _make_causal_mask transformers.models.clip.modeling_clip.CLIPTextEmbeddings.forward = CLIPTextEmbeddings_forward ================================================ FILE: modules/dml/hijack/utils.py ================================================ import torch from typing import Callable from modules.shared import log, opts def catch_nan(func: Callable[[], torch.Tensor]): if not opts.directml_catch_nan: return func() tries = 0 tensor = func() while tensor.isnan().sum() != 0 and tries < 10: if tries == 0: log.warning("NaN is produced. 
Retry with same values...") tries += 1 tensor = func() if tensor.isnan().sum() != 0: log.error("Failed to cover NaN.") return tensor ================================================ FILE: modules/dml/memory.py ================================================ from os import getpid from collections import defaultdict from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path class MemoryProvider: hQuery: HQuery hCounters: defaultdict[str, list[HCounter]] def __init__(self): self.hQuery = HQuery() self.hCounters = defaultdict(list) def get_memory(self, device_id: int) -> dict[str, int]: if len(self.hCounters) == 0: pid = getpid() paths_dedicated = expand_wildcard_path(f"\\GPU Process Memory(pid_{pid}_*_phys_{device_id})\\Dedicated Usage") paths_committed = expand_wildcard_path(f"\\GPU Process Memory(pid_{pid}_*_phys_{device_id})\\Total Committed") for path in paths_dedicated: self.hCounters["dedicated_usage"].append(self.hQuery.add_counter(path)) for path in paths_committed: self.hCounters["total_committed"].append(self.hQuery.add_counter(path)) self.hQuery.collect_data() result = defaultdict(int) for key in self.hCounters: for hCounter in self.hCounters[key]: result[key] += hCounter.get_formatted_value(int) return dict(result) def __del__(self): self.hQuery.close() ================================================ FILE: modules/dml/memory_amd/__init__.py ================================================ from .driver.atiadlxx import ATIADLxx class AMDMemoryProvider: driver: ATIADLxx = ATIADLxx() @staticmethod def mem_get_info(index): usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20) return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize) ================================================ FILE: modules/dml/memory_amd/driver/atiadlxx.py ================================================ import ctypes as C from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, 
ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2 from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK class ATIADLxx: iHyperMemorySize = 0 def __init__(self): self.context = ADL_CONTEXT_HANDLE() ADL2_Main_Control_Create(ADL_Main_Memory_Alloc, 1, C.byref(self.context)) num_adapters = C.c_int(-1) ADL2_Adapter_NumberOfAdapters_Get(self.context, C.byref(num_adapters)) AdapterInfoArray = (AdapterInfo * num_adapters.value)() ADL2_Adapter_AdapterInfo_Get(self.context, C.cast(AdapterInfoArray, LPAdapterInfo), C.sizeof(AdapterInfoArray)) self.devices = [] busNumbers = [] for adapter in AdapterInfoArray: if adapter.iBusNumber not in busNumbers: # filter duplicate device self.devices.append(adapter) busNumbers.append(adapter.iBusNumber) self.iHyperMemorySize = self.get_memory_info2(0).iHyperMemorySize def get_memory_info2(self, adapterIndex: int) -> ADLMemoryInfo2: info = ADLMemoryInfo2() if ADL2_Adapter_MemoryInfo2_Get(self.context, adapterIndex, C.byref(info)) != ADL_OK: raise RuntimeError("ADL2: Failed to get MemoryInfo2") return info def get_dedicated_vram_usage(self, index: int) -> int: usage = C.c_int(-1) if ADL2_Adapter_DedicatedVRAMUsage_Get(self.context, self.devices[index].iAdapterIndex, C.byref(usage)) != ADL_OK: raise RuntimeError("ADL2: Failed to get DedicatedVRAMUsage") return usage.value def get_vram_usage(self, index: int) -> int: usage = C.c_int(-1) if ADL2_Adapter_VRAMUsage_Get(self.context, self.devices[index].iAdapterIndex, C.byref(usage)) != ADL_OK: raise RuntimeError("ADL2: Failed to get VRAMUsage") return usage.value ================================================ FILE: modules/dml/memory_amd/driver/atiadlxx_apis.py ================================================ import ctypes as C 
from platform import platform from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2 if 'Windows' in platform(): atiadlxx = C.WinDLL("atiadlxx.dll") else: atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported. ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int) ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p)) @ADL_MAIN_MALLOC_CALLBACK def ADL_Main_Memory_Alloc(iSize): return C._malloc(iSize) @ADL_MAIN_FREE_CALLBACK def ADL_Main_Memory_Free(lpBuffer): if lpBuffer[0] is not None: C._free(lpBuffer[0]) lpBuffer[0] = None ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create ADL2_Main_Control_Create.restype = C.c_int ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE] ADL2_Adapter_NumberOfAdapters_Get = atiadlxx.ADL2_Adapter_NumberOfAdapters_Get ADL2_Adapter_NumberOfAdapters_Get.restype = C.c_int ADL2_Adapter_NumberOfAdapters_Get.argtypes = [ADL_CONTEXT_HANDLE, C.POINTER(C.c_int)] ADL2_Adapter_AdapterInfo_Get = atiadlxx.ADL2_Adapter_AdapterInfo_Get ADL2_Adapter_AdapterInfo_Get.restype = C.c_int ADL2_Adapter_AdapterInfo_Get.argtypes = [ADL_CONTEXT_HANDLE, LPAdapterInfo, C.c_int] ADL2_Adapter_MemoryInfo2_Get = atiadlxx.ADL2_Adapter_MemoryInfo2_Get ADL2_Adapter_MemoryInfo2_Get.restype = C.c_int ADL2_Adapter_MemoryInfo2_Get.argtypes = [ADL_CONTEXT_HANDLE, C.c_int, C.POINTER(ADLMemoryInfo2)] ADL2_Adapter_DedicatedVRAMUsage_Get = atiadlxx.ADL2_Adapter_DedicatedVRAMUsage_Get ADL2_Adapter_DedicatedVRAMUsage_Get.restype = C.c_int ADL2_Adapter_DedicatedVRAMUsage_Get.argtypes = [ADL_CONTEXT_HANDLE, C.c_int, C.POINTER(C.c_int)] ADL2_Adapter_VRAMUsage_Get = atiadlxx.ADL2_Adapter_VRAMUsage_Get ADL2_Adapter_VRAMUsage_Get.restype = C.c_int ADL2_Adapter_VRAMUsage_Get.argtypes = [ADL_CONTEXT_HANDLE, C.c_int, C.POINTER(C.c_int)] ================================================ FILE: 
modules/dml/memory_amd/driver/atiadlxx_defines.py ================================================ ADL_OK = 0 ================================================ FILE: modules/dml/memory_amd/driver/atiadlxx_structures.py ================================================ import ctypes as C class _ADLPMActivity(C.Structure): __slot__ = [ 'iActivityPercent', 'iCurrentBusLanes', 'iCurrentBusSpeed', 'iCurrentPerformanceLevel', 'iEngineClock', 'iMaximumBusLanes', 'iMemoryClock', 'iReserved', 'iSize', 'iVddc', ] _ADLPMActivity._fields_ = [ # pylint: disable=protected-access ('iActivityPercent', C.c_int), ('iCurrentBusLanes', C.c_int), ('iCurrentBusSpeed', C.c_int), ('iCurrentPerformanceLevel', C.c_int), ('iEngineClock', C.c_int), ('iMaximumBusLanes', C.c_int), ('iMemoryClock', C.c_int), ('iReserved', C.c_int), ('iSize', C.c_int), ('iVddc', C.c_int), ] ADLPMActivity = _ADLPMActivity class _ADLMemoryInfo2(C.Structure): __slot__ = [ 'iHyperMemorySize', 'iInvisibleMemorySize', 'iMemoryBandwidth', 'iMemorySize', 'iVisibleMemorySize', 'strMemoryType' ] _ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access ('iHyperMemorySize', C.c_longlong), ('iInvisibleMemorySize', C.c_longlong), ('iMemoryBandwidth', C.c_longlong), ('iMemorySize', C.c_longlong), ('iVisibleMemorySize', C.c_longlong), ('strMemoryType', C.c_char * 256) ] ADLMemoryInfo2 = _ADLMemoryInfo2 class _AdapterInfo(C.Structure): __slot__ = [ 'iSize', 'iAdapterIndex', 'strUDID', 'iBusNumber', 'iDeviceNumber', 'iFunctionNumber', 'iVendorID', 'strAdapterName', 'strDisplayName', 'iPresent', 'iExist', 'strDriverPath', 'strDriverPathExt', 'strPNPString', 'iOSDisplayIndex', ] _AdapterInfo._fields_ = [ # pylint: disable=protected-access ('iSize', C.c_int), ('iAdapterIndex', C.c_int), ('strUDID', C.c_char * 256), ('iBusNumber', C.c_int), ('iDeviceNumber', C.c_int), ('iFunctionNumber', C.c_int), ('iVendorID', C.c_int), ('strAdapterName', C.c_char * 256), ('strDisplayName', C.c_char * 256), ('iPresent', C.c_int), ('iExist', 
C.c_int), ('strDriverPath', C.c_char * 256), ('strDriverPathExt', C.c_char * 256), ('strPNPString', C.c_char * 256), ('iOSDisplayIndex', C.c_int) ] AdapterInfo = _AdapterInfo LPAdapterInfo = C.POINTER(_AdapterInfo) ADL_CONTEXT_HANDLE = C.c_void_p ================================================ FILE: modules/dml/pdh/__init__.py ================================================ from ctypes import byref, cast, c_size_t from ctypes.wintypes import LPCWSTR, DWORD, WCHAR from typing import NamedTuple, TypeVar from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W from .defines import PDH_FMT_LARGE, PDH_FMT_DOUBLE, PDH_FMT_NOSCALE, PDH_NOEXPANDCOUNTERS, PDH_MORE_DATA, PDH_OK from .msvcrt import malloc from .errors import PDHError class __InternalAbstraction(NamedTuple): flag: int attr_name: str _type_map = { int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"), float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"), } def expand_wildcard_path(path: str) -> list[str]: listLength = DWORD(0) if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA: raise PDHError("Something went wrong.") expanded = (WCHAR * listLength.value)() if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK: raise PDHError(f"Couldn't expand wildcard path '{path}'") result = [] cur = "" for c in expanded: if c == '\0': result.append(cur) cur = "" else: cur += c result.pop() return result T = TypeVar("T", *_type_map.keys()) class HCounter(PDH_HCOUNTER): def get_formatted_value(self, typ: T) -> T: if typ not in _type_map: raise PDHError(f"Invalid value type: {typ}") flag, attr_name = _type_map[typ] value = PDH_FMT_COUNTERVALUE() if PdhGetFormattedCounterValue(self, DWORD(flag 
| PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK: raise PDHError("Couldn't get formatted counter value.") return getattr(value.u, attr_name) def get_formatted_dict(self, typ: T) -> dict[str, T]: if typ not in _type_map: raise PDHError(f"Invalid value type: {typ}") flag, attr_name = _type_map[typ] bufferSize = DWORD(0) itemCount = DWORD(0) if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA: raise PDHError("Something went wrong.") itemBuffer = cast(malloc(c_size_t(bufferSize.value)), PPDH_FMT_COUNTERVALUE_ITEM_W) if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), itemBuffer) != PDH_OK: raise PDHError("Couldn't get formatted counter array.") result: dict[str, T] = {} for i in range(0, itemCount.value): item = itemBuffer[i] result[item.szName] = getattr(item.FmtValue.u, attr_name) return result class HQuery(PDH_HQUERY): def __init__(self): super().__init__() if PdhOpenQueryW(None, None, byref(self)) != PDH_OK: raise PDHError("Couldn't open PDH query.") def add_counter(self, path: str) -> HCounter: hCounter = HCounter() if PdhAddEnglishCounterW(self, LPCWSTR(path), None, byref(hCounter)) != PDH_OK: raise PDHError("Couldn't add counter query.") return hCounter def collect_data(self): if PdhCollectQueryData(self) != PDH_OK: raise PDHError("Couldn't collect query data.") def close(self): if PdhCloseQuery(self) != PDH_OK: raise PDHError("Couldn't close PDH query.") ================================================ FILE: modules/dml/pdh/apis.py ================================================ from ctypes import CDLL, POINTER from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD from typing import Callable from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR pdh = CDLL("pdh.dll") PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW 
# Return and argument types for each pdh.dll entry point used by the package.
PdhExpandWildCardPathW.restype = PDH_FUNCTION
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]

PdhOpenQueryW: Callable = pdh.PdhOpenQueryW
PdhOpenQueryW.restype = PDH_FUNCTION
PdhOpenQueryW.argtypes = [LPCWSTR, DWORD_PTR, POINTER(PDH_HQUERY)]

PdhAddEnglishCounterW: Callable = pdh.PdhAddEnglishCounterW
PdhAddEnglishCounterW.restype = PDH_FUNCTION
PdhAddEnglishCounterW.argtypes = [PDH_HQUERY, LPCWSTR, DWORD_PTR, POINTER(PDH_HCOUNTER)]

PdhCollectQueryData: Callable = pdh.PdhCollectQueryData
PdhCollectQueryData.restype = PDH_FUNCTION
PdhCollectQueryData.argtypes = [PDH_HQUERY]

PdhGetFormattedCounterValue: Callable = pdh.PdhGetFormattedCounterValue
PdhGetFormattedCounterValue.restype = PDH_FUNCTION
PdhGetFormattedCounterValue.argtypes = [PDH_HCOUNTER, DWORD, LPDWORD, PPDH_FMT_COUNTERVALUE]

PdhGetFormattedCounterArrayW: Callable = pdh.PdhGetFormattedCounterArrayW
PdhGetFormattedCounterArrayW.restype = PDH_FUNCTION
PdhGetFormattedCounterArrayW.argtypes = [PDH_HCOUNTER, DWORD, LPDWORD, LPDWORD, PPDH_FMT_COUNTERVALUE_ITEM_W]

PdhCloseQuery: Callable = pdh.PdhCloseQuery
PdhCloseQuery.restype = PDH_FUNCTION
PdhCloseQuery.argtypes = [PDH_HQUERY]


================================================
FILE: modules/dml/pdh/defines.py
================================================
from ctypes import c_int, POINTER
from ctypes.wintypes import DWORD, WCHAR


# PDH calls return a C int status (see ``restype`` in apis.py).
PDH_FUNCTION = c_int
PDH_OK = 0x00000000
# written as a negative number because the status is read back as a signed int
PDH_MORE_DATA = -2147481646#0x800007D2

DWORD_PTR = POINTER(DWORD)
PWSTR = POINTER(WCHAR)
PZZWSTR = POINTER(WCHAR)

# Flags for PdhExpandWildCardPathW.
PDH_NOEXPANDCOUNTERS = 1
PDH_NOEXPANDINSTANCES = 2
PDH_REFRESHCOUNTERS = 4

# Format flags for the PdhGetFormattedCounter* calls.
PDH_FMT_LONG = 0x00000100
PDH_FMT_DOUBLE = 0x00000200
PDH_FMT_LARGE = 0x00000400
PDH_FMT_NOSCALE = 0x00001000
PDH_FMT_1000 = 0x00002000
PDH_FMT_NOCAP100 = 0x00008000


================================================
FILE: modules/dml/pdh/errors.py
================================================
class PDHError(Exception):
    """Raised by the pdh wrappers when a PDH call does not return the expected status."""

    def __init__(self, message: str):
        super().__init__(message)
================================================ FILE: modules/dml/pdh/msvcrt.py ================================================ from ctypes import CDLL, c_void_p, c_size_t msvcrt = CDLL("msvcrt") malloc = msvcrt.malloc malloc.restype = c_void_p malloc.argtypes = [c_size_t] free = msvcrt.free free.restype = None free.argtypes = [c_void_p] ================================================ FILE: modules/dml/pdh/structures.py ================================================ from ctypes import Union, c_double, c_longlong, Structure, POINTER from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR PDH_HQUERY = HANDLE PDH_HCOUNTER = HANDLE class PDH_FMT_COUNTERVALUE_U(Union): _fields_ = [ ("longValue", LONG), ("doubleValue", c_double), ("largeValue", c_longlong), ("AnsiStringValue", LPCSTR), ("WideStringValue", LPCWSTR), ] longValue: int doubleValue: float largeValue: int AnsiStringValue: LPCSTR WideStringValue: LPCWSTR class PDH_FMT_COUNTERVALUE(Structure): _anonymous_ = ("u",) _fields_ = [ ("CStatus", DWORD), ("u", PDH_FMT_COUNTERVALUE_U), ] CStatus: DWORD u: PDH_FMT_COUNTERVALUE_U PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE) class PDH_FMT_COUNTERVALUE_ITEM_W(Structure): _fields_ = [ ("szName", LPWSTR), ("FmtValue", PDH_FMT_COUNTERVALUE), ] szName: str FmtValue: PDH_FMT_COUNTERVALUE PPDH_FMT_COUNTERVALUE_ITEM_W = POINTER(PDH_FMT_COUNTERVALUE_ITEM_W) ================================================ FILE: modules/dml/utils.py ================================================ from typing import Optional, Union import torch rDevice = Union[torch.device, int] def get_device(device: Optional[rDevice]=None) -> torch.device: if device is None: device = torch.dml.current_device() return torch.device(device) ================================================ FILE: modules/errorlimiter.py ================================================ from __future__ import annotations from contextlib import contextmanager from typing import TYPE_CHECKING if TYPE_CHECKING: 
class ErrorLimiterTrigger(BaseException):  # Use BaseException to avoid being caught by "except Exception:".
    """Internal signal raised by `ErrorLimiter.notify` when a limiter's budget reaches zero.

    Attributes are limited to `name`, the identifier of the exhausted limiter.
    """

    def __init__(self, name: str, *args):
        super().__init__(*args)
        self.name = name


class ErrorLimiterAbort(RuntimeError):
    """Raised out of `limit_errors` once too many errors were reported."""

    def __init__(self, msg: str):
        super().__init__(msg)


class ErrorLimiter:
    """Class-level registry of remaining error budgets, keyed by limiter name."""

    _store: dict[str, int] = {}

    @classmethod
    def start(cls, name: str, limit: int = 5):
        """Register (or reset) the limiter `name` with `limit` remaining triggers."""
        cls._store[name] = limit

    @classmethod
    def notify(cls, name: str | Iterable[str]):  # Can be manually triggered if execution is spread across multiple files
        """Count one error against each named limiter that is currently registered.

        Names that are not registered are ignored.

        Raises:
            ErrorLimiterTrigger: For the first limiter whose budget drops to zero or below.
        """
        if isinstance(name, str):
            name = (name,)
        for key in name:
            if key in cls._store:
                cls._store[key] -= 1
                if cls._store[key] <= 0:
                    raise ErrorLimiterTrigger(key)

    @classmethod
    def end(cls, name: str):
        """Unregister the limiter `name`; unknown names are ignored.

        This runs from the `finally` of `limit_errors`. A KeyError here (e.g. when the
        same name was nested and already removed by the inner scope) would replace the
        exception that is actually propagating, so removal is tolerant.
        """
        cls._store.pop(name, None)


@contextmanager
def limit_errors(name: str, limit: int = 5):
    """Limiter for aborting execution after being triggered a specified number of times (default 5).

    Example::

        with limit_errors("identifier", limit=5) as elimit:
            while do_thing():
                if something_bad:
                    print("Something bad happened")
                    elimit()  # In this example, raises ErrorLimiterAbort on the 5th call
                try:
                    something_broken()
                except Exception:
                    print("Encountered an exception")
                    elimit()  # Count is shared across all calls

    Args:
        name (str): Identifier.
        limit (int, optional): Abort after `limit` number of triggers. Defaults to 5.

    Raises:
        ErrorLimiterAbort: Subclass of RuntimeError.

    Yields:
        Callable: Notification function to indicate that an error occurred.
    """
    try:
        ErrorLimiter.start(name, limit)
        yield lambda: ErrorLimiter.notify(name)
    except ErrorLimiterTrigger as e:
        # `from None` hides the internal trigger from the reported traceback.
        raise ErrorLimiterAbort(f"HALTING. Too many errors during '{e.name}'") from None
    finally:
        ErrorLimiter.end(name)
datetime: # If Python minimum version is 3.11+, this function can be replaced with datetime.fromisoformat() trimmed = time_string.rstrip("Z") if "." in trimmed: trimmed = trimmed.split(".")[0] match len(trimmed): case 16: return datetime.strptime(trimmed, "%Y-%m-%dT%H:%M").replace(tzinfo=timezone.utc) case 19: return datetime.strptime(trimmed, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) case _: raise ValueError(f"Unexpected time string format: '{time_string}'") def format_dt(d: datetime, seconds = False) -> str: if d.tzinfo is None: return d.strftime('%Y-%m-%d %H:%M') if seconds: return d.astimezone(timezone.utc).strftime('%Y-%m-%d %H:%M:%S') return d.astimezone(timezone.utc).strftime('%Y-%m-%d %H:%M') def ts2utc(timestamp: int) -> datetime: try: return datetime.fromtimestamp(timestamp, timezone.utc) except Exception: return "unknown" def active(): if shared.opts.disable_all_extensions == "all": return [] elif shared.opts.disable_all_extensions == "user": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] def temp_disable_extensions(): disable_safe = [ 'sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-agent-scheduler', 'clip-interrogator-ext', 'stable-diffusion-webui-images-browser', ] disable_diffusers = [ 'sd-webui-controlnet', 'multidiffusion-upscaler-for-automatic1111', 'a1111-sd-webui-lycoris', 'sd-webui-animatediff', ] disable_themes = [ 'sd-webui-lobe-theme', 'cozy-nest', 'sdnext-modernui', ] disabled = [] if shared.cmd_opts.theme is not None: theme_name = shared.cmd_opts.theme else: theme_name = f'{shared.opts.theme_type.lower()}/{shared.opts.gradio_theme}' if theme_name == 'lobe': disable_themes.remove('sd-webui-lobe-theme') elif theme_name == 'cozy-nest' or theme_name == 'cozy': disable_themes.remove('cozy-nest') elif '/' not in theme_name: # set default themes per type if theme_name == 'standard' or theme_name == 'default': theme_name = 
'standard/black-teal' if theme_name == 'modern': theme_name = 'modern/Default' if theme_name == 'gradio': theme_name = 'gradio/default' if theme_name == 'huggingface': theme_name = 'huggingface/blaaa' if theme_name.lower().startswith('standard') or theme_name.lower().startswith('default'): shared.opts.data['theme_type'] = 'Standard' shared.opts.data['gradio_theme'] = theme_name[9:] elif theme_name.lower().startswith('modern'): shared.opts.data['theme_type'] = 'Modern' shared.opts.data['gradio_theme'] = theme_name[7:] disable_themes.remove('sdnext-modernui') elif theme_name.lower().startswith('huggingface') or theme_name.lower().startswith('gradio') or theme_name.lower().startswith('none'): shared.opts.data['theme_type'] = 'None' shared.opts.data['gradio_theme'] = theme_name else: shared.log.error(f'UI theme invalid: theme="{theme_name}" available={["standard/*", "modern/*", "none/*"]} fallback="standard/black-teal"') shared.opts.data['theme_type'] = 'Standard' shared.opts.data['gradio_theme'] = 'black-teal' for ext in disable_themes: if ext.lower() not in shared.opts.disabled_extensions: disabled.append(ext) if shared.cmd_opts.safe: for ext in disable_safe: if ext.lower() not in shared.opts.disabled_extensions: disabled.append(ext) for ext in disable_diffusers: if ext.lower() not in shared.opts.disabled_extensions: disabled.append(ext) disabled.append('Lora') shared.cmd_opts.controlnet_loglevel = 'WARNING' return disabled class Extension: def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name self.git_name = '' self.path = path self.enabled = enabled self.status = '' self.can_update = False self.is_builtin = is_builtin self.commit_hash = '' self.commit_date = None self.version = '' self.description = '' self.branch = None self.remote = None self.have_info_from_repo = False self.mtime = "2000-01-01T00:00Z" self.ctime = "2000-01-01T00:00Z" def read_info(self, force=False): if self.have_info_from_repo and not force: return 
self.have_info_from_repo = True repo = None self.mtime = datetime.fromtimestamp(os.path.getmtime(self.path)).isoformat() + 'Z' self.ctime = datetime.fromtimestamp(os.path.getctime(self.path)).isoformat() + 'Z' try: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception as e: errors.display(e, f'github info from {self.path}') if repo is None or repo.bare: self.remote = None else: try: self.status = 'unknown' if len(repo.remotes) == 0: shared.log.debug(f"Extension: no remotes info repo={self.name}") return self.git_name = repo.remotes.origin.url.split('.git')[0].split('/')[-1] self.description = repo.description if self.description is None or self.description.startswith("Unnamed repository"): self.description = "[No description]" self.remote = next(repo.remote().urls, None) head = repo.head.commit self.commit_date = repo.head.commit.committed_date try: if repo.active_branch: self.branch = repo.active_branch.name except Exception: self.branch = 'unknown' self.commit_hash = head.hexsha self.version = f"

{self.commit_hash[:8]}

{format_dt(ts2utc(self.commit_date))}

" except Exception as ex: shared.log.error(f"Extension: failed reading data from git repo={self.name}: {ex}") self.remote = None def list_files(self, subdir, extension): from modules import scripts_manager dirpath = os.path.join(self.path, subdir) if not os.path.isdir(dirpath): return [] res = [] for filename in sorted(os.listdir(dirpath)): if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"): continue priority = '50' if os.path.isfile(os.path.join(dirpath, "..", ".priority")): with open(os.path.join(dirpath, "..", ".priority"), "r", encoding="utf-8") as f: priority = str(f.read().strip()) res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority)) if priority != '50': shared.log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}') res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] return res def check_updates(self): try: repo = git.Repo(self.path) except Exception: self.can_update = False return for fetch in repo.remote().fetch(dry_run=True): if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True self.status = "new commits" return try: origin = repo.rev_parse('origin') if repo.head.commit != origin: self.can_update = True self.status = "behind HEAD" return except Exception: self.can_update = False self.status = "unknown (remote error)" return self.can_update = False self.status = "latest" def git_fetch(self, commit='origin'): repo = git.Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. 
def list_extensions():
    """Rebuild the module-level `extensions` list from the extension folders on disk.

    The built-in folder is always scanned; the user folder is scanned as well unless
    `shared.cmd_opts.safe` is set. When two folders contain a directory with the same
    name, the first one found wins and the later one is skipped with a log message.
    Each discovered directory becomes an `Extension`, enabled unless its lower-cased
    name appears in `shared.opts.disabled_extensions` or in the list returned by
    `temp_disable_extensions()`.

    Side effects: clears and refills `extensions`; may remove 'sdnext-modernui' from
    `shared.opts.disabled_extensions` when the Modern theme type is selected.
    """
    extensions.clear()
    if not os.path.isdir(extensions_dir):
        return
    if shared.opts.disable_all_extensions in ("all", "user"):
        shared.log.warning(f"Option set: Disable extensions: {shared.opts.disable_all_extensions}")
    extension_paths = []
    seen_names = set()
    extension_folders = [extensions_builtin_dir] if shared.cmd_opts.safe else [extensions_builtin_dir, extensions_dir]
    for dirname in extension_folders:
        if not os.path.isdir(dirname):
            # Skip only this folder: returning here would abandon the whole scan and
            # leave `extensions` empty even though other folders hold valid entries.
            continue
        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue
            if extension_dirname in seen_names:
                shared.log.info(f'Skipping conflicting extension: {path}')
                continue
            seen_names.add(extension_dirname)
            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
    # The Modern UI is itself an extension, so it cannot stay disabled while selected.
    if shared.opts.theme_type == 'Modern' and 'sdnext-modernui' in shared.opts.disabled_extensions:
        shared.opts.disabled_extensions.remove('sdnext-modernui')
    disabled_extensions = {e.lower() for e in shared.opts.disabled_extensions + temp_disable_extensions()}
    for dirname, path, is_builtin in extension_paths:
        enabled = dirname.lower() not in disabled_extensions
        extension = Extension(name=dirname, path=path, enabled=enabled, is_builtin=is_builtin)
        extensions.append(extension)
    shared.log.debug(f'Extensions: disabled={[e.name for e in extensions if not e.enabled]}')
class ExtraNetworkParams:
    """Arguments of a single extra-network tag, separated into positional and named values.

    A string item containing exactly one ``=`` is stored in `named` as ``key -> value``.
    Every other item (no ``=``, more than one ``=``, or not a string) is kept unchanged
    in `positional`. The untouched input list is available as `items`.
    """

    def __init__(self, items=None):
        self.items = items or []
        self.positional = []
        self.named = {}
        for entry in self.items:
            if isinstance(entry, str):
                key, separator, value = entry.partition('=')
                # Named only when there is one '=' and no further '=' in the remainder.
                if separator and '=' not in value:
                    self.named[key] = value
                    continue
            self.positional.append(entry)
def is_stepwise(en_obj):
    """Report whether any network argument contains an ``@`` separator.

    For every entry in `en_obj`, all positional arguments after the first one and all
    named values are inspected (via their string form).

    Returns:
        True if at least one inspected argument contains ``@``, otherwise False.
    """
    collected = []
    for en in en_obj:
        collected += en.positional[1:]
        collected += en.named.values()
    return any("@" in str(arg) for arg in collected)
def deactivate(p, extra_network_data=None, force=None):
    """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks

    Args:
        p: Processing object; nothing happens when `p.disable_extra_networks` is set.
        extra_network_data: Mapping of network name to its arguments. Falls back to
            `p.network_data` when empty or omitted.
        force: Passed through to each network's `deactivate`. When None, the current
            value of `shared.opts.lora_force_reload` is used. It is read at call time
            because a default written in the signature would be evaluated only once,
            when the module is imported, and would miss later changes to the option.

    Failures of individual networks are reported through `errors.display` and do not
    stop the remaining networks from being deactivated.
    """
    if p.disable_extra_networks:
        return
    if force is None:
        force = shared.opts.lora_force_reload
    # `or {}` keeps the loops below valid when neither source provides any data.
    extra_network_data = extra_network_data or p.network_data or {}
    for extra_network_name in extra_network_data:
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            continue
        try:
            extra_network.deactivate(p, force=force)
        except Exception as e:
            errors.display(e, f"deactivating extra network {extra_network_name}")
    for extra_network_name, extra_network in extra_network_registry.items():
        # Networks present in the data were already handled by the loop above.
        if extra_network_data.get(extra_network_name, None) is not None:
            continue
        try:
            extra_network.deactivate(p, force=force)
        except Exception as e:
            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list) for prompt in prompts: updated_prompt, parsed_extra_data = parse_prompt(prompt) if not extra_data: extra_data = parsed_extra_data updated_prompt_list.append(updated_prompt) return updated_prompt_list, extra_data ================================================ FILE: modules/extras.py ================================================ import os import html import json import time from PIL import Image import torch import gradio as gr import safetensors.torch from modules.merging import merge, merge_utils, modules_sdxl from modules import shared, images, sd_models, sd_vae, sd_samplers, devices def run_pnginfo(image): if image is None: return '', '', '' geninfo, items = images.read_info_from_image(image) items = {**{'parameters': geninfo}, **items} info = '' for key, text in items.items(): if key != 'UserComment': info += f"
{html.escape(str(key))}: {html.escape(str(text))}
" return '', geninfo, info def to_half(tensor, enable): if enable and tensor.dtype == torch.float: return tensor.half() return tensor def run_modelmerger(id_task, **kwargs): # pylint: disable=unused-argument jobid = shared.state.begin('Merge') t0 = time.time() def fail(message): shared.state.textinfo = message shared.state.end(jobid) return [*[gr.update() for _ in range(4)], message] kwargs["models"] = { "model_a": sd_models.get_closest_checkpoint_match(kwargs.get("primary_model_name", None)).filename, "model_b": sd_models.get_closest_checkpoint_match(kwargs.get("secondary_model_name", None)).filename, } if kwargs.get("primary_model_name", None) in [None, 'None']: return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.get_closest_checkpoint_match(kwargs.get("primary_model_name", None)) if kwargs.get("secondary_model_name", None) in [None, 'None']: return fail("Failed: Merging requires a secondary model.") secondary_model_info = sd_models.get_closest_checkpoint_match(kwargs.get("secondary_model_name", None)) if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in merge_utils.TRIPLE_METHODS: return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.") tertiary_model_info = sd_models.get_closest_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in merge_utils.TRIPLE_METHODS else None del kwargs["primary_model_name"] del kwargs["secondary_model_name"] if kwargs.get("tertiary_model_name", None) is not None: kwargs["models"] |= {"model_c": sd_models.get_closest_checkpoint_match(kwargs.get("tertiary_model_name", None)).filename} del kwargs["tertiary_model_name"] if kwargs.get("alpha_base", None) and kwargs.get("alpha_in_blocks", None) and kwargs.get("alpha_mid_block", None) and kwargs.get("alpha_out_blocks", None): try: alpha = [float(x) for x in [kwargs["alpha_base"]] + kwargs["alpha_in_blocks"].split(",") + 
[kwargs["alpha_mid_block"]] + kwargs["alpha_out_blocks"].split(",")] assert len(alpha) == 26 or len(alpha) == 20, "Alpha Block Weights are wrong length (26 or 20 for SDXL)" kwargs["alpha"] = alpha except KeyError as ke: shared.log.warning(f"Merge: Malformed manual block weight: {ke}") elif kwargs.get("alpha_preset", None) or kwargs.get("alpha", None): kwargs["alpha"] = kwargs.get("alpha_preset", kwargs["alpha"]) kwargs.pop("alpha_base", None) kwargs.pop("alpha_in_blocks", None) kwargs.pop("alpha_mid_block", None) kwargs.pop("alpha_out_blocks", None) kwargs.pop("alpha_preset", None) if kwargs.get("beta_base", None) and kwargs.get("beta_in_blocks", None) and kwargs.get("beta_mid_block", None) and kwargs.get("beta_out_blocks", None): try: beta = [float(x) for x in [kwargs["beta_base"]] + kwargs["beta_in_blocks"].split(",") + [kwargs["beta_mid_block"]] + kwargs["beta_out_blocks"].split(",")] assert len(beta) == 26 or len(beta) == 20, "Beta Block Weights are wrong length (26 or 20 for SDXL)" kwargs["beta"] = beta except KeyError as ke: shared.log.warning(f"Merge: Malformed manual block weight: {ke}") elif kwargs.get("beta_preset", None) or kwargs.get("beta", None): kwargs["beta"] = kwargs.get("beta_preset", kwargs["beta"]) kwargs.pop("beta_base", None) kwargs.pop("beta_in_blocks", None) kwargs.pop("beta_mid_block", None) kwargs.pop("beta_out_blocks", None) kwargs.pop("beta_preset", None) if kwargs["device"] == "gpu": kwargs["device"] = devices.device elif kwargs["device"] == "shuffle": kwargs["device"] = torch.device("cpu") kwargs["work_device"] = devices.device else: kwargs["device"] = torch.device("cpu") if kwargs.pop("unload", False): sd_models.unload_model_weights() try: theta_0 = merge.merge_models(**kwargs) except Exception as e: return fail(f"{e}") try: theta_0 = theta_0.to_dict() #TensorDict -> Dict if necessary except Exception: pass bake_in_vae_filename = sd_vae.vae_dict.get(kwargs.get("bake_in_vae", None), None) if bake_in_vae_filename is not None: 
shared.log.info(f"Merge VAE='{bake_in_vae_filename}'") shared.state.textinfo = 'Merge VAE' vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename) for key in vae_dict.keys(): theta_0_key = 'first_stage_model.' + key if theta_0_key in theta_0: theta_0[theta_0_key] = to_half(vae_dict[key], kwargs.get("precision", "fp16") == "fp16") del vae_dict ckpt_dir = shared.opts.ckpt_dir or sd_models.model_path filename = kwargs.get("custom_name", "Unnamed_Merge") filename += "." + kwargs.get("checkpoint_format", None) output_modelname = os.path.join(ckpt_dir, filename) shared.state.textinfo = "merge saving" metadata = None if kwargs.get("save_metadata", False): metadata = {"format": "pt", "sd_merge_models": {}} merge_recipe = { "type": "SDNext", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None, "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None, "merge_mode": kwargs.get('merge_mode', None), "alpha": kwargs.get('alpha', None), "beta": kwargs.get('beta', None), "precision": kwargs.get('precision', None), "custom_name": kwargs.get("custom_name", "Unamed_Merge"), } metadata["sd_merge_recipe"] = json.dumps(merge_recipe) def add_model_metadata(checkpoint_info): checkpoint_info.calculate_shorthash() metadata["sd_merge_models"][checkpoint_info.sha256] = { "name": checkpoint_info.name, "legacy_hash": checkpoint_info.hash, "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None) } metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {})) add_model_metadata(primary_model_info) if secondary_model_info: add_model_metadata(secondary_model_info) if tertiary_model_info: add_model_metadata(tertiary_model_info) metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"]) _, extension = os.path.splitext(output_modelname) if os.path.exists(output_modelname) and not 
kwargs.get("overwrite", False): return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model alredy exists: {output_modelname}"] if extension.lower() == ".safetensors": safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) else: torch.save(theta_0, output_modelname) t1 = time.time() shared.log.info(f"Merge complete: saved='{output_modelname}' time={t1-t0:.2f}") sd_models.list_models() created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None) if created_model: created_model.calculate_shorthash() devices.torch_gc(force=True, reason='merge') shared.state.end(jobid) return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model saved to {output_modelname}"] def run_model_modules(model_type:str, model_name:str, custom_name:str, comp_unet:str, comp_vae:str, comp_te1:str, comp_te2:str, precision:str, comp_scheduler:str, comp_prediction:str, comp_lora:str, comp_fuse:float, meta_author:str, meta_version:str, meta_license:str, meta_desc:str, meta_hint:str, meta_thumbnail:Image.Image, create_diffusers:bool, create_safetensors:bool, debug:bool): status = '' def msg(text, err:bool=False): nonlocal status if err: shared.log.error(f'Modules merge: {text}') else: shared.log.info(f'Modules merge: {text}') status += text + '
' return status if model_type != 'sdxl': yield msg("only SDXL models are supported", err=True) return if len(custom_name) == 0: yield msg("output name is required", err=True) return checkpoint_info = sd_models.get_closest_checkpoint_match(model_name) if checkpoint_info is None: yield msg("input model not found", err=True) return fn = checkpoint_info.filename jobid = shared.state.begin('Merge') yield msg("modules merge starting") yield msg("unload current model") sd_models.unload_model_weights(op='model') modules_sdxl.recipe.name = custom_name modules_sdxl.recipe.author = meta_author modules_sdxl.recipe.version = meta_version modules_sdxl.recipe.desc = meta_desc modules_sdxl.recipe.hint = meta_hint modules_sdxl.recipe.license = meta_license modules_sdxl.recipe.thumbnail = meta_thumbnail modules_sdxl.recipe.base = fn modules_sdxl.recipe.unet = comp_unet modules_sdxl.recipe.vae = comp_vae modules_sdxl.recipe.te1 = comp_te1 modules_sdxl.recipe.te2 = comp_te2 modules_sdxl.recipe.prediction = comp_prediction modules_sdxl.recipe.diffusers = create_diffusers modules_sdxl.recipe.safetensors = create_safetensors modules_sdxl.recipe.fuse = float(comp_fuse) modules_sdxl.recipe.debug = debug loras = [l.strip() if ':' in l else f'{l.strip()}:1.0' for l in comp_lora.split(',') if len(l.strip()) > 0] for lora, strength in [l.split(':') for l in loras]: modules_sdxl.recipe.lora[lora] = float(strength) scheduler = sd_samplers.create_sampler(comp_scheduler, None) modules_sdxl.recipe.scheduler = scheduler.__class__.__name__ if scheduler is not None else None if precision == 'fp32': modules_sdxl.recipe.precision = torch.float32 elif precision == 'bf16': modules_sdxl.recipe.precision = torch.bfloat16 else: modules_sdxl.recipe.precision = torch.float16 modules_sdxl.status = status yield from modules_sdxl.merge() status = modules_sdxl.status devices.torch_gc(force=True, reason='merge') yield msg("modules merge complete") if modules_sdxl.pipeline is not None: checkpoint_info = 
def load_images(self, files):
    """Turn the uploaded entries in `files` into a list of images.

    Accepted entry forms: a string (decoded with `decode_base64_to_image`), a PIL
    image (used as-is), a dict with a 'name' key, or any object with a `name`
    attribute (both opened with `Image.open`). An entry that cannot be loaded is
    logged as a warning and skipped. `None` or an empty `files` yields an empty list.
    """
    def open_entry(entry):
        if isinstance(entry, str):
            from modules.api.api import decode_base64_to_image
            return decode_base64_to_image(entry)
        if isinstance(entry, Image.Image):
            return entry
        if isinstance(entry, dict) and 'name' in entry:
            return Image.open(entry['name'])  # _TemporaryFileWrapper from gr.Files
        if hasattr(entry, 'name'):
            return Image.open(entry.name)  # _TemporaryFileWrapper from gr.Files
        raise ValueError(f'Face: unknown input: {entry}')

    loaded = []
    for file in files or []:
        try:
            loaded.append(open_entry(file))
        except Exception as e:
            shared.log.warning(f'Face: failed to load image: {e}')
    return loaded
") with gr.Row(): models = ['None', 'FaceID', 'FaceSwap', 'InstantID', 'PhotoMaker'] if shared.cmd_opts.experimental: models.append('ReSwapper') mode = gr.Dropdown(label='Mode', choices=['None', 'FaceID', 'FaceSwap', 'InstantID', 'PhotoMaker'], value='None') with gr.Group(visible=False) as cfg_reswapper: with gr.Row(): gr.HTML('  ReSwapper
') with gr.Row(): from modules.face.reswapper import RESWAPPER_MODELS reswapper_model = gr.Dropdown(choices=list(RESWAPPER_MODELS), label='ReSwapper Model', value='ReSwapper 256 0.2') reswapper_original = gr.Checkbox(label='Return original images', value=False) with gr.Group(visible=False) as cfg_faceid: with gr.Row(): gr.HTML('  Tencent AI Lab IP-Adapter FaceID
') with gr.Row(): from modules.face.faceid import FACEID_MODELS ip_model = gr.Dropdown(choices=list(FACEID_MODELS), label='FaceID Model', value='FaceID Base') with gr.Row(visible=True): ip_override = gr.Checkbox(label='Override sampler', value=True) ip_cache = gr.Checkbox(label='Cache model', value=True) with gr.Row(visible=True): ip_strength = gr.Slider(label='Strength', minimum=0.0, maximum=2.0, step=0.01, value=1.0) ip_structure = gr.Slider(label='Structure', minimum=0.0, maximum=1.0, step=0.01, value=1.0) with gr.Group(visible=False) as cfg_faceswap: with gr.Row(): gr.HTML('  InsightFace InSwapper
') with gr.Row(visible=True): fs_cache = gr.Checkbox(label='Cache model', value=True) with gr.Group(visible=False) as cfg_instantid: with gr.Row(): gr.HTML('  InstantX InstantID
') with gr.Row(): id_strength = gr.Slider(label='Strength', minimum=0.0, maximum=2.0, step=0.01, value=1.0) id_conditioning = gr.Slider(label='Control', minimum=0.0, maximum=2.0, step=0.01, value=0.5) with gr.Row(visible=True): id_cache = gr.Checkbox(label='Cache model', value=True) with gr.Group(visible=False) as cfg_photomaker: with gr.Row(): gr.HTML('  Tenecent ARC Lab PhotoMaker
') with gr.Row(): pm_model = gr.Dropdown(label='PhotoMaker Model', choices=['PhotoMaker v1', 'PhotoMaker v2'], value='PhotoMaker v2') pm_trigger = gr.Textbox(label='Trigger word', placeholder="enter one word in prompt") with gr.Row(): pm_strength = gr.Slider(label='Strength', minimum=0.0, maximum=2.0, step=0.01, value=1.0) pm_start = gr.Slider(label='Start', minimum=0.0, maximum=1.0, step=0.01, value=0.5) with gr.Row(): files = gr.File(label='Input images', file_count='multiple', file_types=['image'], interactive=True, height=100) with gr.Row(): gallery = gr.Gallery(show_label=False, value=[]) files.change(fn=self.load_images, inputs=[files], outputs=[gallery]) mode.change(fn=self.mode_change, inputs=[mode], outputs=[cfg_reswapper, cfg_faceid, cfg_faceswap, cfg_instantid, cfg_photomaker]) return [mode, gallery, reswapper_model, reswapper_original, ip_model, ip_override, ip_cache, ip_strength, ip_structure, id_strength, id_conditioning, id_cache, pm_model, pm_trigger, pm_strength, pm_start, fs_cache] def run(self, p: processing.StableDiffusionProcessing, mode, input_images, reswapper_model, reswapper_original, ip_model, ip_override, ip_cache, ip_strength, ip_structure, id_strength, id_conditioning, id_cache, pm_model, pm_trigger, pm_strength, pm_start, fs_cache): # pylint: disable=arguments-differ, unused-argument if mode == 'None': return None if input_images is None or len(input_images) == 0: shared.log.error('Face: no init images') return None if shared.sd_model_type != 'sd' and shared.sd_model_type != 'sdxl': shared.log.error('Face: base model not supported') return None input_images = input_images.copy() for i, image in enumerate(input_images): if isinstance(image, str): from modules.api.api import decode_base64_to_image input_images[i] = decode_base64_to_image(image).convert("RGB") for i, image in enumerate(input_images): if not isinstance(image, Image.Image): input_images[i] = Image.open(image['name']) processed = None self.original_pipeline = shared.sd_model 
self.original_prompt_attention = shared.opts.prompt_attention shared.opts.data['prompt_attention'] = 'fixed' if mode == 'FaceID': # faceid runs as ipadapter in its own pipeline from modules.face.insightface import get_app app = get_app('buffalo_l') from modules.face.faceid import face_id processed_images = face_id(p, app=app, source_images=input_images, model=ip_model, override=ip_override, cache=ip_cache, scale=ip_strength, structure=ip_structure) # run faceid pipeline processed = processing.get_processed(p, images_list=processed_images, seed=p.seed, subseed=p.subseed, index_of_first_image=0) # manually created processed object elif mode == 'PhotoMaker': # photomaker creates pipeline and triggers original process_images from modules.face.insightface import get_app app = get_app('buffalo_l') from modules.face.photomaker import photo_maker photo_maker(p, app=app, input_images=input_images, model=pm_model, trigger=pm_trigger, strength=pm_strength, start=pm_start) elif mode == 'InstantID': if hasattr(p, 'init_images') and p.init_images is not None and len(p.init_images) > 0: shared.log.warning('Face: InstantID with init image not supported') input_images += p.init_images from modules.face.insightface import get_app app=get_app('antelopev2') from modules.face.instantid import instant_id # instantid creates pipeline and triggers original process_images processed = instant_id(p, app=app, source_images=input_images, strength=id_strength, conditioning=id_conditioning, cache=id_cache) if processed is None: # run normal pipeline processed = processing.process_images(p) if mode == 'FaceSwap': # faceswap runs as postprocessing from modules.face.insightface import get_app app=get_app('buffalo_l') from modules.face.faceswap import face_swap processed.images = face_swap(p, app=app, input_images=processed.images, source_image=input_images[0], cache=fs_cache) elif mode == 'ReSwapper': from modules.face.insightface import get_app app = get_app('buffalo_l', resolution=512) from 
modules.face.reswapper import reswapper processed.images = reswapper(p, app=app, source_images=processed.images, target_images=input_images, model_name=reswapper_model, original=reswapper_original) processed.info = processed.infotext(p, 0) processed.infotexts = [processed.info] if shared.opts.samples_save and not p.do_not_save_samples and processed.images is not None: for i, image in enumerate(processed.images): info = processing.create_infotext(p, index=i) images.save_image(image, path=p.outpath_samples, seed=p.all_seeds[i], prompt=p.all_prompts[i], info=info, p=p) return processed def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=unused-argument if self.original_pipeline is not None: shared.sd_model = self.original_pipeline self.original_pipeline = None if self.original_prompt_attention is not None: shared.opts.data['prompt_attention'] = self.original_prompt_attention self.original_prompt_attention = None return processed ================================================ FILE: modules/face/faceid.py ================================================ from typing import List import os import cv2 import torch import numpy as np import diffusers import huggingface_hub as hf from PIL import Image from modules import processing, shared, devices, extra_networks, sd_hijack_freeu, script_callbacks, ipadapter, token_merge from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet FACEID_MODELS = { "FaceID Base": "h94/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin", "FaceID Plus v1": "h94/IP-Adapter-FaceID/ip-adapter-faceid-plus_sd15.bin", "FaceID Plus v2": "h94/IP-Adapter-FaceID/ip-adapter-faceid-plusv2_sd15.bin", "FaceID XL": "h94/IP-Adapter-FaceID/ip-adapter-faceid_sdxl.bin", # "FaceID Portrait v10": "h94/IP-Adapter-FaceID/ip-adapter-faceid-portrait_sd15.bin", # "FaceID Portrait v11": "h94/IP-Adapter-FaceID/ip-adapter-faceid-portrait-v11_sd15.bin", # "FaceID XL Plus v2": 
def hijack_load_ip_adapter(self):
    """Load projection and adapter weights from the module-level `faceid_model_weights` instead of reading a checkpoint file."""
    weights = faceid_model_weights
    self.image_proj_model.load_state_dict(weights["image_proj"])
    # wrap attention processors in a ModuleList so their parameters can be loaded with one state dict
    adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    adapter_layers.load_state_dict(weights["ip_adapter"], strict=False)
shared.log.debug(f'FaceID load: model={model} file="{ip_ckpt}"') faceid_model_weights = torch.load(model_path, map_location="cpu") else: shared.log.debug(f'FaceID cached: model={model} file="{ip_ckpt}"') if "XL Plus" in model and shared.sd_model_type == 'sd': image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" original_load_ip_adapter = IPAdapterFaceIDPlusXL.load_ip_adapter IPAdapterFaceIDPlusXL.load_ip_adapter = hijack_load_ip_adapter faceid_model = IPAdapterFaceIDPlusXL( sd_pipe=shared.sd_model, image_encoder_path=image_encoder_path, ip_ckpt=model_path, lora_rank=128, num_tokens=4, device=devices.device, torch_dtype=devices.dtype, ) elif "XL" in model and shared.sd_model_type == 'sdxl': original_load_ip_adapter = IPAdapterFaceIDXL.load_ip_adapter IPAdapterFaceIDXL.load_ip_adapter = hijack_load_ip_adapter faceid_model = IPAdapterFaceIDXL( sd_pipe=shared.sd_model, ip_ckpt=model_path, lora_rank=128, num_tokens=4, device=devices.device, torch_dtype=devices.dtype, ) elif "Plus" in model and shared.sd_model_type == 'sd': original_load_ip_adapter = IPAdapterFaceIDPlus.load_ip_adapter IPAdapterFaceIDPlus.load_ip_adapter = hijack_load_ip_adapter image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" faceid_model = IPAdapterFaceIDPlus( sd_pipe=shared.sd_model, image_encoder_path=image_encoder_path, ip_ckpt=model_path, lora_rank=128, num_tokens=4, device=devices.device, torch_dtype=devices.dtype, ) elif "Portrait" in model and shared.sd_model_type == 'sd': original_load_ip_adapter = IPAdapterFaceIDPortrait.load_ip_adapter IPAdapterFaceIDPortrait.load_ip_adapter = hijack_load_ip_adapter faceid_model = IPAdapterFaceIDPortrait( sd_pipe=shared.sd_model, ip_ckpt=model_path, num_tokens=16, n_cond=5, device=devices.device, torch_dtype=devices.dtype, ) elif "Base" in model and shared.sd_model_type == 'sd': original_load_ip_adapter = IPAdapterFaceID.load_ip_adapter IPAdapterFaceID.load_ip_adapter = hijack_load_ip_adapter faceid_model = IPAdapterFaceID( 
sd_pipe=shared.sd_model, ip_ckpt=model_path, lora_rank=128, num_tokens=4, device=devices.device, torch_dtype=devices.dtype, ) else: shared.log.error(f'FaceID model not supported: model="{model}" class={shared.sd_model.__class__.__name__}') return None if override: shared.sd_model.scheduler = diffusers.DDIMScheduler( num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1, ) shortcut = "v2" in model faceid_model_name = model face_embeds = [] face_images = [] for i, source_image in enumerate(source_images): np_image = cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR) faces = app.get(np_image) if len(faces) == 0: shared.log.error("FaceID: no faces found") break face_embeds.append(torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)) face_images.append(face_align.norm_crop(np_image, landmark=faces[0].kps, image_size=224)) shared.log.debug(f'FaceID face: i={i+1} score={faces[0].det_score:.2f} gender={"female" if faces[0].gender==0 else "male"} age={faces[0].age} bbox={faces[0].bbox}') p.extra_generation_params[f"FaceID {i+1}"] = f'{faces[0].det_score:.2f} {"female" if faces[0].gender==0 else "male"} {faces[0].age}y' if len(face_embeds) == 0: shared.log.error("FaceID: no faces found") return None face_embeds = torch.cat(face_embeds, dim=0) ip_model_dict = { # main generate dict "num_samples": p.batch_size, "width": p.width, "height": p.height, "num_inference_steps": p.steps, "scale": scale, "guidance_scale": p.cfg_scale, "faceid_embeds": face_embeds.shape, # placeholder } # optional generate dict if shortcut is not None: ip_model_dict["shortcut"] = shortcut if "Plus" in model: ip_model_dict["s_scale"] = structure shared.log.debug(f"FaceID args: {ip_model_dict}") if "Plus" in model: ip_model_dict["face_image"] = face_images ip_model_dict["faceid_embeds"] = face_embeds # overwrite placeholder faceid_model.set_scale(scale) if not p.all_prompts: processing.process_init(p) 
def face_swap(p: processing.StableDiffusionProcessing, app, input_images: List[Image.Image], source_image: Image.Image, cache: bool):
    """
    Swap the first face found in `source_image` onto every face detected in each of `input_images`.

    Returns a list of PIL images, or None when the swapper model fails to load or no face is detected in the source image.
    The swapper model is kept in the module-level `swapper` between calls unless `cache` is false.
    """
    global swapper # pylint: disable=global-statement
    if swapper is None:
        import insightface.model_zoo
        repo_id = 'ezioruan/inswapper_128.onnx'
        model_path = hf.hf_hub_download(repo_id=repo_id, filename='inswapper_128.onnx', cache_dir=shared.opts.hfcache_dir)
        shared.log.debug(f'FaceSwap load: repo="{repo_id}" path="{model_path}"')
        try:
            swapper = insightface.model_zoo.model_zoo.ModelRouter(model_path).get_model()
        except Exception as e:
            shared.log.error(f'FaceSwap load: {e}')
            return None

    source_faces = app.get(cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR))
    if source_faces is None or len(source_faces) == 0:
        shared.log.warning('FaceSwap: No faces detected')
        return None
    source_face = source_faces[0]

    processed_images = []
    for image in input_images:
        frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # detector and swapper work on BGR arrays
        target_faces = app.get(frame)
        for i, face in enumerate(target_faces):
            debug(f'FaceSwap: face={i} source={source_face.bbox} target={face.bbox}')
            frame = swapper.get(img=frame, target_face=face, source_face=source_face, paste_back=True) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        p.extra_generation_params["FaceSwap"] = f'{len(target_faces)}'
        processed_images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))

    if not cache:
        swapper = None
    devices.torch_gc()
    return processed_images
def instant_id(p: processing.StableDiffusionProcessing, app, source_images, strength=1.0, conditioning=0.5, cache=True): # pylint: disable=arguments-differ
    """
    Run processing through a temporary InstantID pipeline assembled from the components of the loaded SDXL pipeline.

    Args:
        p: processing object; `task_args` and `extra_generation_params` are updated in place.
        app: face analysis app used to detect faces via `app.get()`.
        source_images: images providing the face identity; the largest face of each image is used.
        strength: ip-adapter scale.
        conditioning: controlnet conditioning scale.
        cache: keep the controlnet model in the module-level `controlnet_model` between calls.

    Returns:
        The result of `processing.process_images(p)`, or None when there are no input images,
        the loaded pipeline class is not supported or no face is detected in any source image.
    """
    from modules.face.instantid_model import StableDiffusionXLInstantIDPipeline, draw_kps
    from diffusers.models import ControlNetModel
    global controlnet_model # pylint: disable=global-statement

    # prepare pipeline
    if source_images is None or len(source_images) == 0:
        shared.log.warning('InstantID: no input images')
        return None

    c = shared.sd_model.__class__.__name__ if shared.sd_loaded else ''
    if c not in ['StableDiffusionXLPipeline', 'StableDiffusionXLInstantIDPipeline']:
        shared.log.warning(f'InstantID invalid base model: current={c} required=StableDiffusionXLPipeline')
        return None

    def bbox_area(detected):
        # width * height; the previous key multiplied width by y2 only because of missing parentheses
        x0, y0, x1, y1 = detected['bbox'][:4]
        return (x1 - x0) * (y1 - y0)

    # prepare face emb
    face_embeds = []
    face_images = []
    for i, source_image in enumerate(source_images):
        faces = app.get(cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR))
        if faces is None or len(faces) == 0: # skip images without faces instead of failing on an empty list
            shared.log.warning(f'InstantID: no faces found: image={i+1}')
            continue
        face = max(faces, key=bbox_area) # only use the maximum face
        face_embeds.append(torch.from_numpy(face['embedding']))
        face_images.append(draw_kps(source_image, face['kps']))
        # report the face that is actually used, not the first detected one
        p.extra_generation_params[f"InstantID {i+1}"] = f'{face.det_score:.2f} {"female" if face.gender==0 else "male"} {face.age}y'
        shared.log.debug(f'InstantID face: score={face.det_score:.2f} gender={"female" if face.gender==0 else "male"} age={face.age} bbox={face.bbox}')
    if len(face_embeds) == 0:
        shared.log.warning('InstantID: no faces found')
        return None

    shared.log.debug(f'InstantID loading: model={REPO_ID}')
    face_adapter = hf.hf_hub_download(repo_id=REPO_ID, filename="ip-adapter.bin")
    if controlnet_model is None or not cache:
        controlnet_model = ControlNetModel.from_pretrained(REPO_ID, subfolder="ControlNetModel", torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir)
        sd_models.move_model(controlnet_model, devices.device)

    orig_pipeline = shared.sd_model # backup current pipeline definition
    orig_prompt_attention = shared.opts.prompt_attention
    try:
        # create new pipeline
        shared.sd_model = StableDiffusionXLInstantIDPipeline(
            vae = orig_pipeline.vae,
            text_encoder=orig_pipeline.text_encoder,
            text_encoder_2=orig_pipeline.text_encoder_2,
            tokenizer=orig_pipeline.tokenizer,
            tokenizer_2=orig_pipeline.tokenizer_2,
            unet=orig_pipeline.unet,
            scheduler=orig_pipeline.scheduler,
            controlnet=controlnet_model,
            force_zeros_for_empty_prompt=shared.opts.diffusers_force_zeros,
        )
        sd_models.copy_diffuser_options(shared.sd_model, orig_pipeline) # copy options from original pipeline
        sd_models.set_diffuser_options(shared.sd_model) # set all model options such as fp16, offload, etc.
        shared.sd_model.load_ip_adapter_instantid(face_adapter, scale=strength)
        shared.sd_model.set_ip_adapter_scale(strength)
        sd_models.move_model(shared.sd_model, devices.device) # move pipeline to device

        # pipeline specific args
        if not p.all_prompts:
            processing.process_init(p)
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
        shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
        p.task_args['image_embeds'] = face_embeds[0].shape # placeholder
        p.task_args['image'] = face_images[0]
        p.task_args['controlnet_conditioning_scale'] = float(conditioning)
        p.task_args['ip_adapter_scale'] = float(strength)
        shared.log.debug(f"InstantID args: {p.task_args}")
        p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt
        p.task_args['negative_prompt'] = p.all_negative_prompts[0] if p.all_negative_prompts else p.negative_prompt
        p.task_args['image_embeds'] = face_embeds[0] # overwrite placeholder

        # run processing
        processed: processing.Processed = processing.process_images(p)
        shared.sd_model.set_ip_adapter_scale(0)
        p.extra_generation_params['InstantID'] = f'{strength}/{conditioning}'
    finally:
        # always release and restore, even if pipeline creation or processing raised
        if not cache:
            controlnet_model = None
        devices.torch_gc()
        # restore original pipeline
        shared.opts.data['prompt_attention'] = orig_prompt_attention
        shared.sd_model = orig_pipeline
    return processed
def FeedForward(dim, mult=4):
    """Return a bias-free feed-forward stack: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(dim*mult, dim)."""
    hidden_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*layers)
class PerceiverAttention(nn.Module):
    """Attention layer where latent queries attend over the concatenation of image features and the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        queries = self.to_q(latents)
        # keys and values come from a single projection over image features followed by latents
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)
        queries, keys, values = (reshape_tensor(t, self.heads) for t in (queries, keys, values))

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (queries * scale) @ (keys * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        attended = (probs @ values).permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(attended)
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    The last `num_tokens` entries of `encoder_hidden_states` are treated as image prompt tokens:
    they get their own key/value projections and their attention output is added to the text
    attention output, weighted by `scale`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Compute attention for `attn`, adding the image prompt branch when `encoder_hidden_states` is given.

        When `encoder_hidden_states` is None there are no image prompt tokens to split off,
        so only plain self-attention is computed.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        ip_hidden_states = None # stays None for self-attention; previously left unbound and failed below
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = self._attend(attn, query, key, value, attention_mask)

        # for ip-adapter
        if ip_hidden_states is not None:
            ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
            ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))
            ip_hidden_states = self._attend(attn, query, ip_key, ip_value, None)
            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def _attend(self, attn, query, key, value, attention_mask):
        # single attention pass using xformers when it was imported, otherwise softmax scores + bmm
        if xformers_available:
            out = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            out = torch.bmm(attention_probs, value)
        return attn.batch_to_head_dim(out)

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        return hidden_states
FaceAnalysis >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps >>> # download 'antelopev2' under ./models >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) >>> app.prepare(ctx_id=0, det_size=(640, 640)) >>> # download models under ./checkpoints >>> face_adapter = f'./checkpoints/ip-adapter.bin' >>> controlnet_path = f'./checkpoints/ControlNetModel' >>> # load IdentityNet >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.cuda() >>> # load adapter >>> pipe.load_ip_adapter_instantid(face_adapter) >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured" >>> # load an image >>> image = load_image("your-example.jpg") >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] >>> face_emb = face_info['embedding'] >>> face_kps = draw_kps(face_image, face_info['kps']) >>> pipe.set_ip_adapter_scale(0.8) >>> # generate image >>> image = pipe( ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 ... 
def draw_kps(image_pil, kps, color_list=None):
    """Draw keypoints as a stick figure on a black canvas sized like ``image_pil``.

    Args:
        image_pil: image whose ``.size`` (width, height) defines the canvas size.
        kps: sequence of ``(x, y)`` keypoints; keypoints 0, 1, 3 and 4 are each
            joined to keypoint 2 by an elliptical "limb".
        color_list: one RGB tuple per keypoint; a built-in five-colour palette
            is used when omitted.

    Returns:
        A ``PIL.Image`` holding the rendered limbs and keypoint circles.
    """
    palette = color_list
    if palette is None:
        palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]

    limb_half_width = 4
    limb_pairs = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    points = np.array(kps)
    width, height = image_pil.size
    canvas = np.zeros([height, width, 3])

    # Limbs: one filled ellipse per keypoint pair, coloured after the pair's first keypoint.
    for pair in limb_pairs:
        limb_color = palette[pair[0]]
        ends = points[pair]
        xs = ends[:, 0]
        ys = ends[:, 1]
        dx = xs[0] - xs[1]
        dy = ys[0] - ys[1]
        limb_length = (dx ** 2 + dy ** 2) ** 0.5
        limb_angle = math.degrees(math.atan2(dy, dx))
        center = (int(np.mean(xs)), int(np.mean(ys)))
        axes = (int(limb_length / 2), limb_half_width)
        outline = cv2.ellipse2Poly(center, axes, int(limb_angle), 0, 360, 1)
        canvas = cv2.fillConvexPoly(canvas.copy(), outline, limb_color)

    # Limbs are dimmed to 60% before the full-intensity keypoint circles are drawn on top.
    canvas = (canvas * 0.6).astype(np.uint8)

    for point_idx, point in enumerate(points):
        px, py = point
        canvas = cv2.circle(canvas.copy(), (int(px), int(py)), 10, palette[point_idx], -1)

    return PIL.Image.fromarray(canvas.astype(np.uint8))
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
    """Install both InstantID components from one checkpoint file.

    The image projection model is created and loaded first, then the UNet
    attention processors are replaced and their IP-Adapter weights loaded.

    Args:
        model_ckpt: path (or file-like object) accepted by ``torch.load``.
        image_emb_dim: size of the incoming face embedding.
        num_tokens: number of image tokens produced by the projection model.
        scale: initial IP-Adapter scale given to every ``IPAttnProcessor``.
    """
    self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
    self.set_ip_adapter(model_ckpt, num_tokens, scale)


def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
    """Create the ``Resampler`` image projection model and load its weights.

    The checkpoint may either be the projection state dict itself or a dict
    holding it under the ``"image_proj"`` key. The model is stored on
    ``self.image_proj_model`` and ``image_emb_dim`` is remembered in
    ``self.image_proj_model_in_features`` for later reshaping of embeddings.
    """
    image_proj_model = Resampler(
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=num_tokens,
        embedding_dim=image_emb_dim,
        output_dim=self.unet.config.cross_attention_dim,
        ff_mult=4,
    )
    image_proj_model.eval()
    self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)

    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(model_ckpt, map_location="cpu")
    if "image_proj" in state_dict:
        state_dict = state_dict["image_proj"]
    self.image_proj_model.load_state_dict(state_dict)

    self.image_proj_model_in_features = image_emb_dim


def set_ip_adapter(self, model_ckpt, num_tokens, scale):
    """Replace the UNet attention processors and load IP-Adapter weights.

    Processors whose name ends with ``attn1.processor`` get a plain
    ``AttnProcessor``; every other processor gets an ``IPAttnProcessor`` sized
    from ``unet.config.block_out_channels`` according to the block the
    processor belongs to. The checkpoint may be the adapter state dict itself
    or a dict holding it under the ``"ip_adapter"`` key.

    Args:
        model_ckpt: path (or file-like object) accepted by ``torch.load``.
        num_tokens: number of image tokens each ``IPAttnProcessor`` expects.
        scale: initial IP-Adapter scale.

    Raises:
        ValueError: if a cross-attention processor name does not start with
            ``mid_block``, ``up_blocks`` or ``down_blocks`` (previously this
            surfaced as an unbound or stale ``hidden_size``).
    """
    unet = self.unet
    block_out_channels = unet.config.block_out_channels
    attn_procs = {}
    for name in unet.attn_processors:
        if name.endswith("attn1.processor"):
            # attn1 processors carry no IP-Adapter weights.
            attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
            continue

        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Parse the whole index component so multi-digit block ids work too.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_out_channels[block_id]
        else:
            raise ValueError(f"Cannot infer hidden size for attention processor: {name}")

        attn_procs[name] = IPAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=unet.config.cross_attention_dim,
            scale=scale,
            num_tokens=num_tokens,
        ).to(unet.device, dtype=unet.dtype)
    unet.set_attn_processor(attn_procs)

    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(model_ckpt, map_location="cpu")
    if "ip_adapter" in state_dict:
        state_dict = state_dict["ip_adapter"]
    # Wrapping the processors in a ModuleList lets one load_state_dict call fill all of them.
    ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
    ip_layers.load_state_dict(state_dict)
def set_ip_adapter_scale(self, scale):
    """Assign ``scale`` to every ``IPAttnProcessor`` installed on the UNet.

    The UNet is taken from ``self.unet`` when that attribute exists, otherwise
    from the attribute named by ``self.unet_name``.
    """
    unet = self.unet if hasattr(self, "unet") else getattr(self, self.unet_name)
    for processor in unet.attn_processors.values():
        if not isinstance(processor, IPAttnProcessor):
            continue
        processor.scale = scale


def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
    """Project a face embedding into image-prompt tokens.

    The input is copied into a tensor, reshaped to
    ``[1, -1, self.image_proj_model_in_features]`` and, under classifier-free
    guidance, preceded by an all-zero copy along the batch axis. Both the
    embedding and ``self.image_proj_model`` are moved to ``device``/``dtype``
    before the projection model is applied.

    Returns:
        The output of ``self.image_proj_model`` for the prepared embedding.
    """
    if isinstance(prompt_image_emb, torch.Tensor):
        emb = prompt_image_emb.clone().detach()
    else:
        emb = torch.tensor(prompt_image_emb)
    emb = emb.reshape([1, -1, self.image_proj_model_in_features])

    # Zero embedding first, real embedding second when guidance is active.
    pieces = [torch.zeros_like(emb), emb] if do_classifier_free_guidance else [emb]
    emb = torch.cat(pieces, dim=0).to(device=device, dtype=dtype)

    self.image_proj_model = self.image_proj_model.to(device=device, dtype=dtype)
    return self.image_proj_model(emb)
Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = None, **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. 
Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled text embeddings are generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. image_embeds (`torch.FloatTensor`, *optional*): Pre-generated image embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): The ControlNet encoder tries to recognize the content of the input image even if you remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeine class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ if callback_on_step_end_tensor_inputs is None: callback_on_step_end_tensor_inputs = ["latents"] callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not 
isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. Raise error if not correct """ self.check_inputs( prompt, prompt_2, image, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, ) """ self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt, prompt_2, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 
lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 3.2 Encode image prompt prompt_image_emb = self._encode_prompt_image_emb( image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: raise AssertionError # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): original_size = original_size or image[0].shape[-2:] else: original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] controlnet_added_cond_kwargs = { "text_embeds": add_text_embeds.chunk(2)[1], "time_ids": add_time_ids.chunk(2)[1], } else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=prompt_image_emb, controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) if guess_mode and self.do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=encoder_hidden_states, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, 
original_pipeline = None  # pipeline that was active before PhotoMaker replaced it; None when there is nothing to restore


def restore_pipeline():
    """Put the backed-up pipeline back into ``shared.sd_model`` and clear the backup."""
    global original_pipeline # pylint: disable=global-statement
    if original_pipeline is not None:
        shared.sd_model = original_pipeline
        original_pipeline = None


def _face_area(face):
    """Return the bounding-box area of a detected face (bbox read as ``[x1, y1, x2, y2]``)."""
    x1, y1, x2, y2 = face['bbox'][:4]
    return (x2 - x1) * (y2 - y1)


def photo_maker(p: processing.StableDiffusionProcessing, app, model: str, input_images, trigger, strength, start): # pylint: disable=arguments-differ
    """Switch the loaded SDXL pipeline to PhotoMaker and prepare ``p`` for it.

    The function validates its inputs, backs up the current pipeline in the
    module-level ``original_pipeline``, switches ``shared.sd_model`` to
    ``PhotoMakerStableDiffusionXLPipeline``, downloads and loads the adapter,
    and stores the PhotoMaker arguments in ``p.task_args``. It does not run
    generation itself and always returns ``None``.

    Args:
        p: processing object; its prompts, steps, ``task_args`` and
            ``extra_generation_params`` are read and updated.
        app: face analyser exposing ``get(bgr_image)``; only used for v2 models.
        model: model name; a name containing ``"v2"`` selects PhotoMaker V2.
        input_images: reference images of the identity.
        trigger: trigger word that must be present in the prompt.
        strength: adapter weight passed to ``set_adapters``.
        start: fraction of ``p.steps`` used as ``start_merge_step``.

    Returns:
        Always ``None``.
    """
    global original_pipeline # pylint: disable=global-statement
    from modules.face.photomaker_pipeline import PhotoMakerStableDiffusionXLPipeline

    # prepare pipeline
    if not input_images:
        shared.log.warning('PhotoMaker: no input images')
        return None
    if not trigger:
        shared.log.warning('PhotoMaker: no trigger word')
        return None

    c = shared.sd_model.__class__.__name__ if shared.sd_loaded else ''
    if c != 'StableDiffusionXLPipeline':
        shared.log.warning(f'PhotoMaker invalid base model: current={c} required=StableDiffusionXLPipeline')
        return None

    # validate prompt: every trigger token id must occur exactly once for both tokenizers
    if not p.all_prompts:
        processing.process_init(p)
        p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
    trigger_ids = shared.sd_model.tokenizer.encode(trigger) + shared.sd_model.tokenizer_2.encode(trigger)
    prompt_ids1 = shared.sd_model.tokenizer.encode(p.all_prompts[0])
    prompt_ids2 = shared.sd_model.tokenizer_2.encode(p.all_prompts[0])
    for t in trigger_ids:
        if prompt_ids1.count(t) != 1:
            shared.log.error(f'PhotoMaker: trigger word not matched in prompt: {trigger} ids={trigger_ids} prompt={p.all_prompts[0]} ids={prompt_ids1}')
            return None
        if prompt_ids2.count(t) != 1:
            # report the ids of the tokenizer that actually failed the check
            shared.log.error(f'PhotoMaker: trigger word not matched in prompt: {trigger} ids={trigger_ids} prompt={p.all_prompts[0]} ids={prompt_ids2}')
            return None

    # create new pipeline
    original_pipeline = shared.sd_model # backup current pipeline definition
    shared.sd_model = sd_models.switch_pipe(PhotoMakerStableDiffusionXLPipeline, shared.sd_model)
    shared.sd_model.restore_pipeline = restore_pipeline
    sd_models.set_diffuser_options(shared.sd_model) # set all model options such as fp16, offload, etc.
    sd_models.apply_balanced_offload(shared.sd_model) # apply balanced offload

    orig_prompt_attention = shared.opts.prompt_attention
    shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
    try:
        p.task_args['input_id_images'] = input_images
        p.task_args['start_merge_step'] = int(start * p.steps)
        p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt

        is_v2 = 'v2' in model
        if is_v2:
            repo_id, fn = 'TencentARC/PhotoMaker-V2', 'photomaker-v2.bin'
        else:
            repo_id, fn = 'TencentARC/PhotoMaker', 'photomaker-v1.bin'

        photomaker_path = hf.hf_hub_download(repo_id=repo_id, filename=fn, repo_type="model", cache_dir=shared.opts.hfcache_dir)
        shared.log.debug(f'PhotoMaker: model="{model}" uri="{repo_id}/{fn}" images={len(input_images)} trigger={trigger} args={p.task_args}')

        # load photomaker adapter
        shared.sd_model.load_photomaker_adapter(
            photomaker_path,
            trigger_word=trigger,
            weight_name=fn,
            pm_version='v2' if is_v2 else 'v1',
            device=devices.device,
            cache_dir=shared.opts.hfcache_dir,
        )
        shared.sd_model.set_adapters(["photomaker"], adapter_weights=[strength])

        # analyze faces
        if is_v2:
            id_embed_list = []
            for i, source_image in enumerate(input_images):
                faces = app.get(cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR))
                if not faces:
                    shared.log.warning(f'PhotoMaker: face={i+1} no faces detected')
                    continue
                face = max(faces, key=_face_area) # only use the largest face
                id_embed_list.append(torch.from_numpy(face['embedding']))
                shared.log.debug(f'PhotoMaker: face={i+1} score={face.det_score:.2f} gender={"female" if face.gender==0 else "male"} age={face.age} bbox={face.bbox}')
            if id_embed_list:
                p.task_args['id_embeds'] = torch.stack(id_embed_list).to(device=devices.device, dtype=devices.dtype)
            else:
                shared.log.error('PhotoMaker: no faces detected in any input image')

        p.extra_generation_params['PhotoMaker'] = f'{strength}'

        # unload photomaker adapter
        shared.sd_model.unload_lora_weights()
    finally:
        # never leave the global option modified, even when download or adapter loading fails
        shared.opts.data['prompt_attention'] = orig_prompt_attention
    return None
+ residual return x class FuseModule(nn.Module): def __init__(self, embed_dim): super().__init__() self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) self.layer_norm = nn.LayerNorm(embed_dim) def fuse_fn(self, prompt_embeds, id_embeds): unstacked_prompt_embeds = prompt_embeds.unbind(0) stacked_id_embeds = torch.cat([unstacked_prompt_embeds[0].unsqueeze(0), id_embeds], dim=-1) # monkey patch stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds stacked_id_embeds = self.mlp2(stacked_id_embeds) stacked_id_embeds = self.layer_norm(stacked_id_embeds) return stacked_id_embeds def forward( self, prompt_embeds, id_embeds, class_tokens_mask, ) -> torch.Tensor: # id_embeds shape: [b, max_num_inputs, 1, 2048] id_embeds = id_embeds.to(prompt_embeds.dtype) num_inputs = class_tokens_mask.sum().unsqueeze(0) batch_size, max_num_inputs = id_embeds.shape[:2] # seq_length: 77 seq_length = prompt_embeds.shape[1] # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] flat_id_embeds = id_embeds.view( -1, id_embeds.shape[-2], id_embeds.shape[-1] ) # valid_id_mask [b*max_num_inputs] valid_id_mask = ( torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] < num_inputs[:, None] ) valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) class_tokens_mask = class_tokens_mask.view(-1) valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # slice out the image token embeddings image_token_embeds = prompt_embeds[class_tokens_mask] stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) return 
class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
    """PhotoMaker v1 ID encoder: CLIP vision tower plus a second projection and a fuse module.

    The two projections (768 from the CLIP config, 1280 from
    `visual_projection_2`) are concatenated to the 2048-wide embedding the
    fuse module is built for.
    """

    def __init__(self):
        super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
        self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
        self.fuse_module = FuseModule(2048)

    def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): # pylint: disable=arguments-differ
        """Encode the ID images and merge them into `prompt_embeds` at the masked positions."""
        batch, num_inputs, channels, height, width = id_pixel_values.shape
        pixels = id_pixel_values.view(batch * num_inputs, channels, height, width)

        # element [1] of the vision tower output; presumably the pooled embedding (confirm against transformers)
        shared_id_embeds = self.vision_model(pixels)[1]
        proj_a = self.visual_projection(shared_id_embeds).view(batch, num_inputs, 1, -1)
        proj_b = self.visual_projection_2(shared_id_embeds).view(batch, num_inputs, 1, -1)
        id_embeds = torch.cat((proj_a, proj_b), dim=-1)

        return self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
class FacePerceiverResampler(torch.nn.Module):
    """Stack of perceiver attention + feed-forward layers that refines given latents.

    Unlike `Resampler`, the latents are supplied by the caller in `forward`
    instead of being a learned parameter.
    """

    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()
        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)
        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, latents, x):
        """Attend from `latents` to the projected features `x`; returns normalized latents."""
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        return self.norm_out(latents)


# FFN
def FeedForward(dim, mult=4):
    """Build a bias-free LayerNorm -> Linear -> GELU -> Linear block of width `dim * mult`."""
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    """Split the last dimension of `x` into attention heads.

    Returns a tensor of shape (bs, heads, length, dim_per_head).
    """
    bs, length, _width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # the layout stays 4-D: this reshape keeps (bs, n_heads, length, dim_per_head)
    # and does NOT merge batch and heads
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    """Cross-attention where latents query the concatenation of features and latents."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # kept as an attribute; forward() derives its own scale from dim_head
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        batch, num_latents, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    """Perceiver resampler with learned latent queries.

    Optionally adds positional embeddings to the input and prepends latents
    derived from the mean-pooled input sequence.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        """Resample `x` of shape (b, n, embedding_dim) into (b, num_latents, output_dim)."""
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        # explicit None check: truthiness of an nn.Sequential goes through __len__
        if self.to_latents_from_mean_pooled_seq is not None:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    """Mean of `t` along `dim`, counting only positions where `mask` is True.

    Without a mask this is a plain mean. The mask is expanded from (b, n) to
    (b, n, 1); the denominator is clamped so an all-False row does not divide
    by zero.
    """
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)


# Keyword arguments for the CLIPVisionConfig used by the v2 ID encoder.
VISION_CONFIG_DICT = {
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768
}
class MLP(nn.Module):
    """Pre-norm two-layer GELU perceptron with an optional skip connection.

    The skip connection adds the input to the output, so it requires
    `in_dim == out_dim`.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        out = self.layernorm(x)
        out = self.act_fn(self.fc1(out))
        out = self.fc2(out)
        return out + x if self.use_residual else out


class QFormerPerceiver(nn.Module):
    """Expands one ID embedding into `num_tokens` tokens and refines them against image features."""

    def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.use_residual = use_residual
        widened = id_embeddings_dim*ratio
        self.token_proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, widened),
            nn.GELU(),
            nn.Linear(widened, cross_attention_dim*num_tokens),
        )
        self.token_norm = nn.LayerNorm(cross_attention_dim)
        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=128,
            heads=cross_attention_dim // 128,
            embedding_dim=embedding_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, x, last_hidden_state):
        """Project `x` to tokens, then let them attend to `last_hidden_state`."""
        tokens = self.token_proj(x).reshape(-1, self.num_tokens, self.cross_attention_dim)
        tokens = self.token_norm(tokens)  # cls token
        out = self.perceiver_resampler(tokens, last_hidden_state)  # retrieve from patch tokens
        if self.use_residual:
            out = tokens + 1.0 * out
        return out


class FuseModule(nn.Module):
    """Merges ID embeddings into the prompt embeddings at the class-token positions."""

    def __init__(self, embed_dim):
        super().__init__()
        self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def fuse_fn(self, prompt_embeds, id_embeds):
        """Concatenate token and ID embeddings, project, refine and normalize."""
        paired = torch.cat([prompt_embeds, id_embeds], dim=-1)
        fused = self.mlp1(paired) + prompt_embeds
        fused = self.mlp2(fused)
        return self.layer_norm(fused)

    def forward(
        self,
        prompt_embeds,
        id_embeds,
        class_tokens_mask,
    ) -> torch.Tensor:
        """Overwrite the masked prompt tokens with their fused counterparts.

        The write happens in place on a flattened view of `prompt_embeds`,
        so the caller's tensor is updated too.
        """
        # id_embeds shape: [b, max_num_inputs, 1, 2048]
        id_embeds = id_embeds.to(prompt_embeds.dtype)
        n_masked = class_tokens_mask.sum().unsqueeze(0)
        batch_size, max_num_inputs = id_embeds.shape[:2]
        # seq_length: 77
        seq_length = prompt_embeds.shape[1]

        # id_rows shape: [b*max_num_inputs, 1, 2048]
        id_rows = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
        # row_is_valid [b*max_num_inputs]
        positions = torch.arange(max_num_inputs, device=id_rows.device)[None, :]
        row_is_valid = positions < n_masked[:, None]
        valid_rows = id_rows[row_is_valid.flatten()]

        tokens = prompt_embeds.view(-1, prompt_embeds.shape[-1])
        token_mask = class_tokens_mask.view(-1)
        valid_rows = valid_rows.view(-1, valid_rows.shape[-1])

        # slice out the image token embeddings
        class_tokens = tokens[token_mask]
        replacement = self.fuse_fn(class_tokens, valid_rows)
        assert token_mask.sum() == replacement.shape[0], f"{token_mask.sum()} != {replacement.shape[0]}"
        tokens.masked_scatter_(token_mask[:, None], replacement.to(tokens.dtype))
        return tokens.view(batch_size, seq_length, -1)
class PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken(CLIPVisionModelWithProjection):
    """PhotoMaker v2 ID encoder.

    Combines the CLIP vision tower with externally supplied face embeddings
    (`id_embeds`) through a `QFormerPerceiver`, producing `num_tokens` tokens
    per ID image that the fuse module writes into the prompt embeddings.
    """

    def __init__(self, id_embeddings_dim=512):
        super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT))
        self.fuse_module = FuseModule(2048)
        self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)

        # projection
        self.num_tokens = 2
        self.cross_attention_dim = 2048
        self.qformer_perceiver = QFormerPerceiver(
            id_embeddings_dim,
            self.cross_attention_dim,
            self.num_tokens,
        )

    def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds): # pylint: disable=arguments-differ, arguments-renamed
        """Encode ID images + face embeddings and merge them into `prompt_embeds`."""
        batch, num_inputs, channels, height, width = id_pixel_values.shape
        pixels = id_pixel_values.view(batch * num_inputs, channels, height, width)

        # element [0] of the vision tower output, used as the sequence the perceiver attends to
        last_hidden_state = self.vision_model(pixels)[0]
        flat_ids = id_embeds.view(batch * num_inputs, -1)
        id_tokens = self.qformer_perceiver(flat_ids, last_hidden_state)
        id_tokens = id_tokens.view(batch, num_inputs, self.num_tokens, -1)

        return self.fuse_module(prompt_embeds, id_tokens, class_tokens_mask)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The standard deviation is taken per sample over every non-batch dimension.
    """
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_axes, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_axes, keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the
    call; any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed through to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps; requires the scheduler's `set_timesteps` to accept a `timesteps` argument.
        sigmas (`List[float]`, *optional*):
            Custom sigmas; requires the scheduler's `set_timesteps` to accept a `sigmas` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's `timesteps` after the call and the number of inference
        steps (the length of the schedule when a custom one was supplied).

    Raises:
        ValueError: if both `timesteps` and `sigmas` are passed, or the scheduler does not support the
        custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # default path: let the scheduler space the steps itself
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
@validate_hf_hub_args
def load_photomaker_adapter(
    self,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    weight_name: str,
    subfolder: str = '',
    trigger_word: str = 'img',
    pm_version: str = 'v2',
    device: torch.device = None,
    **kwargs,
):
    """
    Load the PhotoMaker ID encoder and its LoRA weights into this pipeline.

    Sets `num_tokens`, `pm_version`, `trigger_word`, `id_image_processor` and `id_encoder` on the pipeline,
    loads the LoRA weights under the adapter name `"photomaker"` and registers the trigger word as a
    special token with the tokenizer(s).

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
                  with exactly the two top-level keys `id_encoder` and `lora_weights`.

        weight_name (`str`):
            The weight name NOT the path to the weight. A name ending in `.safetensors` is read with
            `safe_open`; anything else is read with `torch.load`.
        subfolder (`str`, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        trigger_word (`str`, *optional*, defaults to `"img"`):
            The trigger word is used to identify the position of class word in the text prompt,
            and it is recommended not to set it as a common word.
            This trigger word must be placed after the class word when used, otherwise, it will affect the
            performance of the personalized generation.
        pm_version (`str`, defaults to `"v2"`):
            Which ID encoder to build: `"v1"` or `"v2"`.
        device (`torch.device`, *optional*):
            Device the ID encoder is moved to; its dtype follows `self.unet.dtype`.
        **kwargs:
            `cache_dir`, `force_download`, `proxies`, `local_files_only`, `token` and `revision` are
            forwarded to the model file lookup.

    Raises:
        ValueError: the state dict does not consist of exactly `id_encoder` and `lora_weights`.
        NotImplementedError: `pm_version` is neither `"v1"` nor `"v2"`.
    """
    # Load the main state dict first.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    sections = ("id_encoder", "lora_weights")

    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
        model_file = _get_model_file(
            pretrained_model_name_or_path_or_dict,
            weights_name=weight_name,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )
        if weight_name.endswith(".safetensors"):
            state_dict = {section: {} for section in sections}
            with safe_open(model_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for section in sections:
                        prefix = f"{section}."
                        if key.startswith(prefix):
                            # strip only the leading prefix; str.replace would also remove
                            # the same text if it reappeared deeper inside the key
                            state_dict[section][key[len(prefix):]] = f.get_tensor(key)
                            break
        else:
            # NOTE(security): torch.load unpickles the file; only use non-safetensors
            # weights from a trusted source (consider `weights_only=True` where supported).
            state_dict = torch.load(model_file, map_location="cpu")
    else:
        state_dict = pretrained_model_name_or_path_or_dict

    # compare as a set so a dict built in the other key order is not rejected
    if set(state_dict.keys()) != set(sections):
        raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.")

    self.num_tokens = 2 # pylint: disable=attribute-defined-outside-init
    self.pm_version = pm_version # pylint: disable=attribute-defined-outside-init
    self.trigger_word = trigger_word # pylint: disable=attribute-defined-outside-init

    # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
    self.id_image_processor = CLIPImageProcessor() # pylint: disable=attribute-defined-outside-init
    if pm_version == "v1": # PhotoMaker v1
        id_encoder = PhotoMakerIDEncoder()
    elif pm_version == "v2": # PhotoMaker v2
        id_encoder = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken()
    else:
        raise NotImplementedError(f"The PhotoMaker version [{pm_version}] does not support")

    id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
    id_encoder = id_encoder.to(device, dtype=self.unet.dtype)
    self.id_encoder = id_encoder # pylint: disable=attribute-defined-outside-init

    # load lora into models
    self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")

    # Add trigger word token
    if self.tokenizer is not None:
        self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)

    self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
def encode_prompt_with_trigger_word(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
    ### Added args
    num_id_images: int = 1,
    class_tokens_mask: Optional[torch.LongTensor] = None,
):
    """
    Encode the prompt, expanding the class word that precedes the trigger word.

    The trigger word token is removed from the token ids and the class word
    directly before it is repeated `num_id_images * self.num_tokens` times;
    `class_tokens_mask` marks those repeated positions.

    Args:
        prompt / prompt_2: text for the first / second tokenizer; `prompt_2` falls back to `prompt`.
        device: target device, defaults to `self._execution_device`.
        num_images_per_prompt: repeat factor applied to the negative and the pooled embeddings.
            Note that `prompt_embeds` itself is NOT repeated in this method.
        do_classifier_free_guidance: whether negative embeddings are produced.
        negative_prompt / negative_prompt_2: negative text; zeros are used instead when `negative_prompt`
            is None and `self.config.force_zeros_for_empty_prompt` is set.
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds:
            pre-computed embeddings; when `prompt_embeds` is given the tokenization step is skipped.
        lora_scale: LoRA scale applied to the text encoders for the duration of this call.
        clip_skip: selects `hidden_states[-(clip_skip + 2)]` instead of `hidden_states[-2]`.
        num_id_images: number of ID images the class token is expanded for.
        class_tokens_mask: used as given when `prompt_embeds` is supplied; otherwise rebuilt here.
            NOTE(review): it is moved to `device` unconditionally, so it looks required whenever
            `prompt_embeds` is passed - confirm against the caller.

    Returns:
        Tuple of (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds, class_tokens_mask).

    Raises:
        ValueError: the prompt does not contain exactly one trigger word, or the negative prompt batch
            size differs from the prompt batch size.
        TypeError: `negative_prompt` and `prompt` have different types.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale # pylint: disable=attribute-defined-outside-init

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Find the token id of the trigger word
    image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE: the loop variable rebinds the `prompt` argument; code after the loop sees the last value
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                # truncated part is decoded but not used or reported anywhere in this method
                _removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])

            clean_index = 0
            clean_input_ids = []
            class_token_index = []
            # Find out the corresponding class word token based on the newly added trigger word token
            # only the first row of the tokenized batch is scanned
            for _i, token_id in enumerate(text_input_ids.tolist()[0]):
                if token_id == image_token_id:
                    # the class word is the token right before the trigger word
                    class_token_index.append(clean_index - 1)
                else:
                    clean_input_ids.append(token_id)
                    clean_index += 1

            if len(class_token_index) != 1:
                raise ValueError(
                    f"PhotoMaker currently does not support multiple trigger words in a single prompt.\
                        Trigger word: {self.trigger_word}, Prompt: {prompt}."
                )
            class_token_index = class_token_index[0]

            # Expand the class word token and corresponding mask
            class_token = clean_input_ids[class_token_index]
            clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images * self.num_tokens + \
                clean_input_ids[class_token_index+1:]

            # Truncation or padding
            max_len = tokenizer.model_max_length
            if len(clean_input_ids) > max_len:
                clean_input_ids = clean_input_ids[:max_len]
            else:
                clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
                    max_len - len(clean_input_ids)
                )

            # True exactly on the positions holding the repeated class token
            class_tokens_mask = [True if class_token_index <= i < class_token_index+(num_id_images * self.num_tokens) else False \
                for i in range(len(clean_input_ids))]

            clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
            class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)

            prompt_embeds = text_encoder(clean_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # per-encoder hidden states are joined along the feature dimension
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    class_tokens_mask = class_tokens_mask.to(device=device)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt # pylint: disable=no-member
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): # pylint: disable=redefined-argument-from-local
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # pad/truncate the negative prompt to the sequence length of the positive embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask
negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], # Added parameters (for PhotoMaker) input_id_images: PipelineImageInput = None, start_merge_step: int = 10, class_tokens_mask: Optional[torch.LongTensor] = None, id_embeds: Optional[torch.FloatTensor] = None, prompt_embeds_text_only: Optional[torch.FloatTensor] = None, pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Only the parameters introduced by PhotoMaker are discussed here. For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py Args: input_id_images (`PipelineImageInput`, *optional*): Input ID Image to work with PhotoMaker. class_tokens_mask (`torch.LongTensor`, *optional*): Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word. prompt_embeds_text_only (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. 
Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale # pylint: disable=attribute-defined-outside-init self._guidance_rescale = guidance_rescale # pylint: disable=attribute-defined-outside-init self._clip_skip = clip_skip # pylint: disable=attribute-defined-outside-init self._cross_attention_kwargs = cross_attention_kwargs # pylint: disable=attribute-defined-outside-init self._denoising_end = denoising_end # pylint: disable=attribute-defined-outside-init self._interrupt = False # pylint: disable=attribute-defined-outside-init if prompt_embeds is not None and class_tokens_mask is None: raise ValueError( "If `prompt_embeds` are provided, `class_tokens_mask` also have to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." ) # check the input id images if input_id_images is None: raise ValueError( "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline." ) if not isinstance(input_id_images, list): input_id_images = [input_id_images] # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) num_id_images = len(input_id_images) ( prompt_embeds, _, pooled_prompt_embeds, _, class_tokens_mask, ) = self.encode_prompt_with_trigger_word( prompt=prompt, prompt_2=prompt_2, device=device, num_id_images=num_id_images, class_tokens_mask=class_tokens_mask, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # 4. Encode input prompt without the trigger word for delayed conditioning # encode, remove trigger word token, then decode tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False) trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word) tokens_text_only.remove(trigger_word_token) prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False) ( prompt_embeds_text_only, negative_prompt_embeds, pooled_prompt_embeds_text_only, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt_text_only, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds_text_only, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds_text_only, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 6. 
Prepare the input ID images dtype = next(self.id_encoder.parameters()).dtype if not isinstance(input_id_images[0], torch.Tensor): id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values # pylint: disable=used-before-assignment id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # pylint: disable=used-before-assignment # 7. Get the update text embedding with the stacked ID embedding if id_embeds is not None: id_embeds = id_embeds.unsqueeze(0).to(device=device, dtype=dtype) prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds) else: prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # 8. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 9. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 11. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 11.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps # pylint: disable=no-member - (self.denoising_end * self.scheduler.config.num_train_timesteps) # pylint: disable=no-member ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 12. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) # pylint: disable=attribute-defined-outside-init with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if i <= start_merge_step: current_prompt_embeds = torch.cat( [negative_prompt_embeds, prompt_embeds_text_only], dim=0 ) if self.do_classifier_free_guidance else prompt_embeds_text_only add_text_embeds = torch.cat( [negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0 ) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only else: current_prompt_embeds = torch.cat( [negative_prompt_embeds, prompt_embeds], dim=0 ) if self.do_classifier_free_guidance else prompt_embeds add_text_embeds = torch.cat( [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 ) if self.do_classifier_free_guidance else pooled_prompt_embeds added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=current_prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if 
self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() # pylint: disable=possibly-used-before-assignment if output_type != "latent": # make sure the 
VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # pylint: disable=attribute-defined-outside-init # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) # apply watermark if available # if self.watermark is not None: # image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: modules/face/reswapper.py ================================================ from typing import List import os import cv2 import torch import numpy as np import huggingface_hub as 
hf from PIL import Image from modules import processing, shared, devices RESWAPPER_REPO = 'somanchiu/reswapper' RESWAPPER_MODELS = { "ReSwapper 256 0.2": "reswapper_256-1567500.pth", "ReSwapper 256 0.1": "reswapper_256-1399500.pth", "ReSwapper 128 0.2": "reswapper-429500.pth", "ReSwapper 128 0.1": "reswapper-1019500.pth", } reswapper_model = None reswapper_name = None debug = shared.log.trace if os.environ.get("SD_FACE_DEBUG", None) is not None else lambda *args, **kwargs: None dtype = devices.dtype def get_model(model_name: str): global reswapper_model, reswapper_name # pylint: disable=global-statement if reswapper_model is None or reswapper_name != model_name: try: fn = RESWAPPER_MODELS.get(model_name) url = hf.hf_hub_download(repo_id=RESWAPPER_REPO, filename=fn, repo_type="model", cache_dir=shared.opts.hfcache_dir) from modules.face.reswapper_model import ReSwapperModel reswapper_model = ReSwapperModel() reswapper_model.load_state_dict(torch.load(url, map_location='cpu'), strict=False) reswapper_model = reswapper_model.to(device=devices.device, dtype=dtype) reswapper_model.eval() reswapper_name = model_name shared.log.info(f'ReSwapper: model="{model_name}" url="{url}" cls={reswapper_model.__class__.__name__}') if reswapper_model is None: shared.log.error(f'ReSwapper: model="{model_name}" fn="{fn}" url="{url}" failed to load model') return reswapper_model except Exception as e: shared.log.error(f'ReSwapper: model="{model_name}" fn="{fn}" url="{url}" {e}') return reswapper_model def reswapper( p: processing.StableDiffusionProcessing, app, source_images: List[Image.Image], target_images: List[Image.Image], model_name: str, original: bool, ): from modules.face import reswapper_utils as utils if source_images is None or len(source_images) == 0: shared.log.warning('ReSwapper: no input images') return None processed_images = [] if original: processed_images += source_images model = get_model(model_name) if model is None: return source_images model = 
model.to(device=devices.device) i = 0 for x, image in enumerate(source_images): image = image.convert('RGB') source_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) source_faces = app.get(source_np) if len(source_faces) == 0: shared.log.error(f"ReSwapper: image={x+1} no source faces found") return source_images if len(source_faces) != len(target_images): shared.log.warning(f"ReSwapper: image={x+1} source-faces={len(source_faces)} target-images={len(target_images)}") for y, source_face in enumerate(source_faces): target_image = target_images[y] if y < len(target_images) else target_images[-1] target_image = target_image.convert('RGB') target_np = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR) target_faces = app.get(target_np) if len(target_faces) != 1: shared.log.error(f"ReSwapper: image={x+1} source-faces={y+1} target-faces={len(target_faces)} must be exactly one") return source_images target_face = target_faces[0] source_str = f'score:{source_face.det_score:.2f} gender:{"female" if source_face.gender==0 else "male"} age:{source_face.age}' target_str = f'score:{target_face.det_score:.2f} gender:{"female" if target_face.gender==0 else "male"} age:{target_face.age}' shared.log.debug(f'ReSwapper image={x+1} face={y+1} source="{source_str}" target="{target_str}"') source_latent = utils.getLatent(source_face) source_tensor = torch.from_numpy(source_latent).to(device=devices.device, dtype=dtype) resolution = 256 if '256' in model_name else 128 target_np = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR) target_aligned, M = utils.norm_crop2(target_np, target_face.kps, resolution) target_blob = utils.getBlob(target_aligned, (resolution, resolution)) target_tensor = torch.from_numpy(target_blob).to(device=devices.device, dtype=dtype) with devices.inference_context(): swapped_tensor = model(target_tensor, source_tensor) swapped_tensor = swapped_tensor.float() swapped_face = (swapped_tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy() * 
255).astype(np.uint8) swapped_face = cv2.cvtColor(swapped_face, cv2.COLOR_RGB2BGR) swapped_np = utils.blend_swapped_image(swapped_face, source_np, M) swapped_image = Image.fromarray(cv2.cvtColor(swapped_np, cv2.COLOR_BGR2RGB)) processed_images.append(swapped_image) i += 1 p.extra_generation_params['ReSwapper'] = f'faces={i}' devices.torch_gc() return processed_images ================================================ FILE: modules/face/reswapper_model.py ================================================ # original: import torch import torch.nn as nn import torch.nn.functional as F class ReSwapperModel(nn.Module): def __init__(self): super(ReSwapperModel, self).__init__() # self.pad = nn.ReflectionPad2d(3) # Encoder for target face self.target_encoder = nn.Sequential( # self.pad, nn.Conv2d(3, 128, kernel_size=7, stride=1, padding=0), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), nn.LeakyReLU(0.2), ) # for style_block in self.target_encoder: # for param in style_block.parameters(): # param.requires_grad = False # Style blocks self.style_blocks = nn.ModuleList([ StyleBlock(1024, 1024, blockIndex) for blockIndex in range(6) ]) # Decoder (upsampling) self.decoder = nn.Sequential( nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2) ) self.decoderPart1 = nn.Sequential( nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2) ) self.decoderPart2 = nn.Sequential( # self.pad, nn.Conv2d(128, 3, kernel_size=7, stride=1, padding=0), nn.Tanh() ) def forward(self, target, source): # Encode target face target = F.pad(target, pad=(3, 3, 3, 3), mode='reflect') target_features = self.target_encoder(target) # Apply style blocks x = target_features for style_block in self.style_blocks: x = 
style_block(x, source) # Decode # x = F.interpolate(x, scale_factor=2, mode='linear') x = F.upsample( x, scale_factor=2, # specify the desired height and width mode='bilinear', # 'linear' in 2D is called 'bilinear' align_corners=False # this is typically False for ONNX compatibility ) output = self.decoder(x) output = F.upsample( output, scale_factor=2, # specify the desired height and width mode='bilinear', # 'linear' in 2D is called 'bilinear' align_corners=False # this is typically False for ONNX compatibility ) output = self.decoderPart1(output) output = F.pad(output, pad=(3, 3, 3, 3), mode='reflect') output = self.decoderPart2(output) return (output + 1) / 2 class StyleBlock(nn.Module): def __init__(self, in_channels, out_channels, blockIndex): super(StyleBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0) self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0) self.style1 = nn.Linear(512, 2048) self.style2 = nn.Linear(512, 2048) self.style = [self.style1, self.style2] self.blockIndex = blockIndex def normalizeConvRMS(self, conv): x = conv - torch.mean(conv, dim=[2, 3], keepdim=True) # centeredConv squareX = x * x meanSquaredX = torch.mean(squareX, dim=[2, 3], keepdim=True) rms = torch.sqrt(meanSquaredX + 0.00000001) return (1 / rms) * x def forward(self, residual, style): # print(f'Forward: {self.blockIndex}') style1024 = [] for index in range(2): style1 = self.style[index](style) style1 = torch.unsqueeze(style1, 2) style1 = torch.unsqueeze(style1, 3) first_half = style1[:, :1024, :, :] second_half = style1[:, 1024:, :, :] style1024.append([first_half, second_half]) conv1 = self.normalizeConvRMS(self.conv1(F.pad(residual, pad=(1, 1, 1, 1), mode='reflect'))) out = F.relu(conv1 * style1024[0][0] + style1024[0][1]) out = F.pad(out, pad=(1, 1, 1, 1), mode='reflect') conv2 = self.normalizeConvRMS(self.conv2(out)) out = conv2 * style1024[1][0] + style1024[1][1] return residual + out 
================================================ FILE: modules/face/reswapper_utils.py ================================================ # https://github.com/somanchiu/ReSwapper/blob/GAN/Image.py import cv2 import numpy as np input_std = 255.0 input_mean = 0.0 def get_emap(): emap = np.load("modules/face/reswapper_emap.npy") # https://github.com/somanchiu/ReSwapper/blob/GAN/emap.npy return emap def postprocess_face(face_tensor): face_tensor = face_tensor.squeeze().cpu().detach() face_np = (face_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8) face_np = cv2.cvtColor(face_np, cv2.COLOR_RGB2BGR) return face_np def getBlob(aimg, input_size = (128, 128)): blob = cv2.dnn.blobFromImage(aimg, 1.0 / input_std, input_size, (input_mean, input_mean, input_mean), swapRB=True) return blob def getLatent(source_face): latent = source_face.normed_embedding.reshape((1,-1)) emap = get_emap() latent = np.dot(latent, emap) latent /= np.linalg.norm(latent) return latent def blend_swapped_image(swapped_face, target_image, M): h, w = target_image.shape[:2] M_inv = cv2.invertAffineTransform(M) warped_face = cv2.warpAffine(swapped_face, M_inv, (w, h),borderValue=0.0) img_white = np.full((swapped_face.shape[0], swapped_face.shape[1]), 255, dtype=np.float32) img_mask = cv2.warpAffine(img_white, M_inv, (w, h), borderValue=0.0) img_mask[img_mask > 20] = 255 # pylint: disable=unsupported-assignment-operation mask_h_inds, mask_w_inds = np.where(img_mask == 255) if len(mask_h_inds) > 0 and len(mask_w_inds) > 0: # safety check mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) mask_size = int(np.sqrt(mask_h * mask_w)) k = max(mask_size // 10, 10) kernel = np.ones((k, k), np.uint8) img_mask = cv2.erode(img_mask, kernel, iterations=1) k = max(mask_size // 20, 5) kernel_size = (k, k) blur_size = tuple(2 * i + 1 for i in kernel_size) img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) img_mask = img_mask / 255.0 img_mask = np.reshape(img_mask, 
[img_mask.shape[0], img_mask.shape[1], 1]) result = img_mask * warped_face + (1 - img_mask) * target_image.astype(np.float32) result = result.astype(np.uint8) return result def drawKeypoints(image, keypoints, colorBGR, keypointsRadius=2): for kp in keypoints: x, y = int(kp[0]), int(kp[1]) cv2.circle(image, (x, y), radius=keypointsRadius, color=colorBGR, thickness=-1) # BGR format, -1 means filled circle ### https://github.com/somanchiu/ReSwapper/blob/GAN/face_align.py arcface_dst = np.array( [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) def estimate_norm(lmk, image_size=112,mode='arcface'): # pylint: disable=unused-argument from skimage import transform as trans if image_size%112==0: ratio = float(image_size)/112.0 diff_x = 0 else: ratio = float(image_size)/128.0 diff_x = 8.0*ratio ratio = float(image_size)/112.0 diff_x = 0 dst = arcface_dst * ratio dst[:,0] += diff_x if image_size%112==0: ratio = float(image_size)/112.0 diff_x = 0 else: ratio = float(image_size)/128.0 diff_x = 8.0*ratio dst = arcface_dst * ratio dst[:,0] += diff_x tform = trans.SimilarityTransform() tform.estimate(lmk, dst) M = tform.params[0:2, :] return M def norm_crop(img, landmark, image_size=112, mode='arcface'): M = estimate_norm(landmark, image_size, mode) warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) return warped def norm_crop2(img, landmark, image_size=112, mode='arcface'): M = estimate_norm(landmark, image_size, mode) warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) return warped, M def square_crop(im, S): if im.shape[0] > im.shape[1]: height = S width = int(float(im.shape[1]) / im.shape[0] * S) scale = float(S) / im.shape[0] else: width = S height = int(float(im.shape[0]) / im.shape[1] * S) scale = float(S) / im.shape[1] resized_im = cv2.resize(im, (width, height)) det_im = np.zeros((S, S, 3), dtype=np.uint8) det_im[:resized_im.shape[0], :resized_im.shape[1], 
:] = resized_im return det_im, scale def transform(data, center, output_size, scale, rotation): from skimage import transform as trans scale_ratio = scale rot = float(rotation) * np.pi / 180.0 t1 = trans.SimilarityTransform(scale=scale_ratio) cx = center[0] * scale_ratio cy = center[1] * scale_ratio t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) t3 = trans.SimilarityTransform(rotation=rot) t4 = trans.SimilarityTransform(translation=(output_size / 2, output_size / 2)) t = t1 + t2 + t3 + t4 M = t.params[0:2] cropped = cv2.warpAffine(data, M, (output_size, output_size), borderValue=0.0) return cropped, M def trans_points2d(pts, M): new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) new_pts[i] = new_pt[0:2] return new_pts def trans_points3d(pts, M): scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) #print(scale) new_pts = np.zeros(shape=pts.shape, dtype=np.float32) for i in range(pts.shape[0]): pt = pts[i] new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) new_pt = np.dot(M, new_pt) #print('new_pt', new_pt.shape, new_pt) new_pts[i][0:2] = new_pt[0:2] new_pts[i][2] = pts[i][2] * scale return new_pts def trans_points(pts, M): if pts.shape[1] == 2: return trans_points2d(pts, M) else: return trans_points3d(pts, M) ================================================ FILE: modules/face_restoration.py ================================================ from modules import shared class FaceRestoration: def name(self): return "None" def restore(self, np_image): return np_image def restore_faces(np_image, p=None): face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] if len(face_restorers) == 0: return np_image face_restorer = face_restorers[0] return face_restorer.restore(np_image, p) ================================================ FILE: 
modules/facelib/__init__.py ================================================ ================================================ FILE: modules/facelib/detection/__init__.py ================================================ import os from copy import deepcopy import torch from torch import nn from ..utils import load_file_from_url from ..utils import download_pretrained_models from ..detection.yolov5face.models.common import Conv from .retinaface.retinaface import RetinaFace from .yolov5face.face_detector import YoloDetector from modules import paths model_dir = os.path.join(paths.models_path, 'Codeformer') def init_detection_model(model_name, half=False, device='cuda'): if 'retinaface' in model_name: model = init_retinaface_model(model_name, half, device) elif 'YOLOv5' in model_name: model = init_yolov5face_model(model_name, device) else: raise NotImplementedError(f'{model_name} is not implemented.') return model def init_retinaface_model(model_name, half=False, device='cuda'): if model_name == 'retinaface_resnet50': model = RetinaFace(network_name='resnet50', half=half) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' elif model_name == 'retinaface_mobile0.25': model = RetinaFace(network_name='mobile0.25', half=half) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' else: raise NotImplementedError(f'{model_name} is not implemented.') model_path = load_file_from_url(url=model_url, model_dir=model_dir, progress=True, file_name=None) load_net = torch.load(model_path, map_location=lambda storage, loc: storage) # remove unnecessary 'module.' 
# Reference positions of the five facial key points, as (x, y) coordinates
# expressed inside a DEFAULT_CROP_SIZE crop.
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]

# (width, height) of the crop the reference points above refer to.
DEFAULT_CROP_SIZE = (96, 112)


class FaceWarpException(Exception):
    """Error raised when face-warping arguments are inconsistent."""

    def __str__(self):
        # `super` has to be *called*: the bare `super.__str__(self)` resolved to
        # object.__str__, which rendered the exception's repr instead of its message.
        return 'In File {}:{}'.format(__file__, super().__str__())


def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
    """
    Function:
    ----------
        get reference 5 key points according to crop settings:
        0. Set default crop_size:
            if default_square:
                crop_size = (112, 112)
            else:
                crop_size = (96, 112)
        1. Pad the crop_size by inner_padding_factor in each side;
        2. Resize crop_size into (output_size - outer_padding*2),
            pad into output_size with outer_padding;
        3. Output reference_5point;
    Parameters:
    ----------
        @output_size: (w, h) or None
            size of aligned face image
        @inner_padding_factor: float in [0, 1]
            padding factor applied to both inner width and height
        @outer_padding: (w_pad, h_pad)
            number of pixels added around the resized inner region
        @default_square: True or False
            if True:
                default crop_size = (112, 112)
            else:
                default crop_size = (96, 112);
        !!! make sure, if output_size is not None:
                (output_size - outer_padding*2)
                = some_scale * (default crop_size * (1.0 + inner_padding_factor*2))
    Returns:
    ----------
        @reference_5point: 5x2 np.array
            each row is a pair of transformed coordinates (x, y)
    Raises:
    ----------
        FaceWarpException: when the padding / output_size combination is invalid
    """
    # Normalise so that lists or arrays compare equal to (0, 0) below.
    outer_padding = tuple(outer_padding)

    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

    # 0) make the inner region a square
    if default_square:
        size_diff = max(tmp_crop_size) - tmp_crop_size
        tmp_5pts += size_diff / 2
        tmp_crop_size += size_diff

    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
        return tmp_5pts

    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
        if output_size is None:
            return tmp_5pts
        raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

    # check output size
    if not (0 <= inner_padding_factor <= 1.0):
        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
        # Deduce the output size from the paddings. The cast must apply to the
        # scaled array; the previous code called .astype() on a Python float
        # and therefore always raised AttributeError on this path.
        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
        output_size = output_size + np.array(outer_padding)
    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')

    # 1) pad the inner region according inner_padding_factor
    if inner_padding_factor > 0:
        size_diff = tmp_crop_size * inner_padding_factor * 2
        tmp_5pts += size_diff / 2
        tmp_crop_size += np.round(size_diff).astype(np.int32)

    # 2) resize the padded inner region
    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2

    # The padded inner region and the target region must share an aspect ratio,
    # because a single scale factor is applied to both axes.
    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
        raise FaceWarpException('Must have (output_size - outer_padding)'
                                '= some_scale * (crop_size * (1.0 + inner_padding_factor)')

    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
    tmp_5pts = tmp_5pts * scale_factor

    # 3) add outer_padding to make output_size
    reference_5point = tmp_5pts + np.array(outer_padding)

    return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
    """
    Function:
    ----------
        get affine transform matrix 'tfm' from src_pts to dst_pts
    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points matrix, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points matrix, each row is a pair of coordinates (x, y)
    Returns:
    ----------
        @tfm: 2x3 np.array (float32)
            transform matrix from src_pts to dst_pts; the identity transform
            when the least-squares system has rank below 2
    """
    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    # Homogeneous coordinates: [x, y, 1] @ A ~= [x', y', 1]
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    A, _, rank, _ = np.linalg.lstsq(src_pts_, dst_pts_)

    if rank == 3:
        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
    elif rank == 2:
        # Rank-deficient fit: keep the linear part and drop the translation.
        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])

    return tfm
align_type='smilarity'): """ Function: ---------- apply affine transform 'trans' to uv Parameters: ---------- @src_img: 3x3 np.array input image @facial_pts: could be 1)a list of K coordinates (x,y) or 2) Kx2 or 2xK np.array each row or col is a pair of coordinates (x, y) @reference_pts: could be 1) a list of K coordinates (x,y) or 2) Kx2 or 2xK np.array each row or col is a pair of coordinates (x, y) or 3) None if None, use default reference facial points @crop_size: (w, h) output face image size @align_type: transform type, could be one of 1) 'similarity': use similarity transform 2) 'cv2_affine': use the first 3 points to do affine transform, by calling cv2.getAffineTransform() 3) 'affine': use all points to do affine transform Returns: ---------- @face_img: output face image with size (w, h) = @crop_size """ if reference_pts is None: if crop_size[0] == 96 and crop_size[1] == 112: reference_pts = REFERENCE_FACIAL_POINTS else: default_square = False inner_padding_factor = 0 outer_padding = (0, 0) output_size = crop_size reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, default_square) ref_pts = np.float32(reference_pts) ref_pts_shp = ref_pts.shape if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') if ref_pts_shp[0] == 2: ref_pts = ref_pts.T src_pts = np.float32(facial_pts) src_pts_shp = src_pts.shape if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') if src_pts_shp[0] == 2: src_pts = src_pts.T if src_pts.shape != ref_pts.shape: raise FaceWarpException('facial_pts and reference_pts must have the same shape') if align_type == 'cv2_affine': tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) elif align_type == 'affine': tfm = get_affine_transform_matrix(src_pts, ref_pts) else: tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) face_img = cv2.warpAffine(src_img, 
class MatlabCp2tormException(Exception):
    """Error raised when a similarity transform cannot be estimated."""

    def __str__(self):
        # `super` has to be called; the bare `super.__str__(self)` produced the
        # exception's repr rather than its message.
        return 'In File {}:{}'.format(__file__, super().__str__())


def tformfwd(trans, uv):
    """
    Function:
    ----------
        apply affine transform 'trans' to uv
    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)
    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of transformed coordinates (x, y)
    """
    # Row-vector convention: [u, v, 1] @ trans = [x, y, 1]
    uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
    xy = np.dot(uv, trans)
    xy = xy[:, 0:-1]
    return xy


def tforminv(trans, uv):
    """
    Function:
    ----------
        apply the inverse of affine transform 'trans' to uv
    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)
    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of inverse-transformed coordinates (x, y)
    """
    Tinv = inv(trans)
    xy = tformfwd(Tinv, uv)
    return xy


def findNonreflectiveSimilarity(uv, xy, options=None):
    """
    Function:
    ----------
        least-squares fit of a non-reflective similarity (rotation, uniform
        scale, translation) such that [u, v, 1] @ T ~= [x, y, 1]
    Parameters:
    ----------
        @uv: Kx2 np.array
            source points
        @xy: Kx2 np.array
            destination points
        @options: ignored; kept for signature compatibility (K is always 2)
    Returns:
    ----------
        @T: 3x3 np.array, transform from uv to xy
        @Tinv: 3x3 np.array, transform from xy to uv
    Raises:
    ----------
        MatlabCp2tormException: fewer than two distinct points were supplied
    """
    K = 2  # minimum number of distinct point pairs

    M = xy.shape[0]
    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector

    # Linear system for r = [sc, ss, tx, ty]:
    #   u = sc*x + ss*y + tx
    #   v = sc*y - ss*x + ty
    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
    X = np.vstack((tmp1, tmp2))

    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    U = np.vstack((u, v))

    # We know that X * r = U
    if rank(X) >= 2 * K:
        r, _, _, _ = lstsq(X, U, rcond=-1)
        r = np.squeeze(r)
    else:
        raise MatlabCp2tormException('cp2tform:twoUniquePointsReq')
    sc = r[0]
    ss = r[1]
    tx = r[2]
    ty = r[3]

    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
    T = inv(Tinv)
    # Remove numerical noise from the homogeneous column.
    T[:, 2] = np.array([0, 0, 1])

    return T, Tinv


def findSimilarity(uv, xy, options=None):
    """
    Function:
    ----------
        fit both a non-reflective and a reflective similarity from uv to xy
        and return whichever reproduces xy with the smaller residual
    Parameters:
    ----------
        @uv: Kx2 np.array
            source points
        @xy: Kx2 np.array
            destination points (left unmodified)
        @options: ignored; kept for signature compatibility
    Returns:
    ----------
        @trans: 3x3 np.array, transform from uv to xy
        @trans_inv: 3x3 np.array, transform from xy to uv
    """
    # Solve for trans1
    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)

    # Solve for trans2
    # manually reflect the xy data across the Y-axis.
    # Work on a copy: `xyR = xy` aliased the caller's array, so the input was
    # mirrored in place and both residuals below were measured against the
    # mirrored points instead of the real targets.
    xyR = xy.copy()
    xyR[:, 0] = -1 * xyR[:, 0]

    trans2r, _ = findNonreflectiveSimilarity(uv, xyR, options)

    # manually reflect the tform to undo the reflection done on xyR
    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])

    trans2 = np.dot(trans2r, TreflectY)

    # Figure out if trans1 or trans2 is better
    xy1 = tformfwd(trans1, uv)
    norm1 = norm(xy1 - xy)

    xy2 = tformfwd(trans2, uv)
    norm2 = norm(xy2 - xy)

    if norm1 <= norm2:
        return trans1, trans1_inv
    trans2_inv = inv(trans2)
    return trans2, trans2_inv


def get_similarity_transform(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'trans':
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y, 1] = [u, v, 1] * trans
    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        @reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform
    Returns:
    ----------
        @trans: 3x3 np.array
            transform matrix from uv to xy
        @trans_inv: 3x3 np.array
            inverse of trans, transform matrix from xy to uv
    """
    if reflective:
        trans, trans_inv = findSimilarity(src_pts, dst_pts)
    else:
        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)

    return trans, trans_inv


def cvt_tform_mat_for_cv2(trans):
    """
    Function:
    ----------
        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T
    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix from uv to xy
    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    # Drop the homogeneous column and switch to the column-vector convention.
    cv2_trans = trans[:, 0:2].T

    return cv2_trans


def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T
    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        @reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform
    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    trans, _ = get_similarity_transform(src_pts, dst_pts, reflective)
    cv2_trans = cvt_tform_mat_for_cv2(trans)

    return cv2_trans
def generate_config(network_name):
    """Return a fresh RetinaFace configuration dict for the given backbone.

    Args:
        network_name: 'mobile0.25' or 'resnet50'.

    Returns:
        dict with anchor settings, training hyper-parameters and the
        backbone layer/channel layout.

    Raises:
        NotImplementedError: for any other network name.
    """
    if network_name == 'mobile0.25':
        backbone = {
            'name': 'mobilenet0.25',
            'batch_size': 32,
            'ngpu': 1,
            'epoch': 250,
            'decay1': 190,
            'decay2': 220,
            'image_size': 640,
            'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
            'in_channel': 32,
            'out_channel': 64,
        }
    elif network_name == 'resnet50':
        backbone = {
            'name': 'Resnet50',
            'batch_size': 24,
            'ngpu': 4,
            'epoch': 100,
            'decay1': 70,
            'decay2': 90,
            'image_size': 840,
            'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
            'in_channel': 256,
            'out_channel': 256,
        }
    else:
        raise NotImplementedError(f'network_name={network_name}')

    # Settings that both backbones share.
    common = {
        'min_sizes': [[16, 32], [64, 128], [256, 512]],
        'steps': [8, 16, 32],
        'variance': [0.1, 0.2],
        'clip': False,
        'loc_weight': 2.0,
        'gpu_train': True,
    }

    # Assemble in the original key order: name, shared keys, backbone keys.
    cfg = {'name': backbone.pop('name')}
    cfg.update(common)
    cfg.update(backbone)
    return cfg
def __detect_faces(self, inputs):
    """Run the network on a prepared batch and build the matching priors.

    Side effects: stores ``self.scale`` (w, h repeated twice) and
    ``self.scale1`` (w, h repeated five times) for the current input size.

    Returns:
        (loc, conf, landmarks, priors) with priors moved to ``device``.
    """
    # Scale vectors derived from the input's spatial size.
    height, width = inputs.shape[2:]
    wh = [width, height]
    self.scale = torch.tensor(wh * 2, dtype=torch.float32).to(device)
    self.scale1 = torch.tensor(wh * 5, dtype=torch.float32).to(device)

    # Forward pass.
    inputs = inputs.to(device)
    if self.half_inference:
        inputs = inputs.half()
    loc, conf, landmarks = self(inputs)

    # Prior boxes for this input resolution.
    priors = PriorBox(self.cfg, image_size=inputs.shape[2:]).forward().to(device)

    return loc, conf, landmarks, priors
def detect_faces(
    self,
    image,
    conf_threshold=0.8,
    nms_threshold=0.4,
    use_origin_size=True,
):
    """
    Detect faces in a single image.

    Params:
        image: input image, handed to self.transform (which converts a
            PIL image to BGR; other inputs are treated as BGR arrays)
        conf_threshold: detections scoring at or below this are dropped
        nms_threshold: IoU threshold passed to py_cpu_nms
        use_origin_size: forwarded to self.transform; when True the image
            is not rescaled

    Returns:
        np.array with one row per kept detection: 4 box coordinates,
        the score, then 10 landmark values.
    """
    image, self.resize = self.transform(image, use_origin_size)
    image = image.to(device)
    if self.half_inference:
        image = image.half()
    # Subtract the per-channel mean held in self.mean_tensor.
    image = image - self.mean_tensor

    loc, conf, landmarks, priors = self.__detect_faces(image)

    # Decode boxes, then map them back to the coordinates of the unresized image.
    boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
    boxes = boxes * self.scale / self.resize
    boxes = boxes.cpu().numpy()

    # Column 1 of the confidence output is used as the detection score.
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

    landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
    landmarks = landmarks * self.scale1 / self.resize
    landmarks = landmarks.cpu().numpy()

    # ignore low scores
    inds = np.where(scores > conf_threshold)[0]
    boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

    # sort by descending score
    order = scores.argsort()[::-1]
    boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

    # do NMS on [x1, y1, x2, y2, score] rows
    bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(bounding_boxes, nms_threshold)
    bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
    return np.concatenate((bounding_boxes, landmarks), axis=1)
def batched_transform(self, frames, use_origin_size):
    """
    Convert a batch of frames into an NCHW float tensor for the detector.

    Arguments:
        frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
            float, BGR format).
        use_origin_size: whether to use origin size.

    Returns:
        (frames, resize): the [n, c, h, w] tensor and the scale factor
        that was applied (1 when use_origin_size is True).
    """
    from_PIL = isinstance(frames[0], Image.Image)

    # convert to opencv format
    if from_PIL:
        frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
        frames = np.asarray(frames, dtype=np.float32)

    # testing scale
    im_size_min = np.min(frames[0].shape[0:2])
    im_size_max = np.max(frames[0].shape[0:2])
    resize = float(self.target_size) / float(im_size_min)

    # prevent bigger axis from being more than max_size
    if np.round(resize * im_size_max) > self.max_size:
        resize = float(self.max_size) / float(im_size_max)
    resize = 1 if use_origin_size else resize

    if from_PIL:
        if resize != 1:
            # Re-stack into one array: the resized frames used to be left as a
            # Python list, which has no .transpose() and crashed below.
            frames = np.asarray(
                [
                    cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
                    for frame in frames
                ],
                dtype=np.float32)
        # NHWC -> NCHW, then to torch.
        frames = frames.transpose((0, 3, 1, 2))
        frames = torch.from_numpy(frames)
    else:
        # NHWC -> NCHW first: F.interpolate resizes the trailing spatial dims,
        # so running it on NHWC data scaled width and channels instead of
        # height and width.
        frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
        if resize != 1:
            frames = F.interpolate(frames, scale_factor=resize)

    return frames, resize
def conv_bn(inp, oup, stride=1, leaky=0):
    """3x3 convolution (padding 1, no bias) -> BatchNorm -> in-place LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_bn_no_relu(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) -> BatchNorm, without activation."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*layers)


def conv_bn1X1(inp, oup, stride, leaky=0):
    """1x1 convolution (no bias) -> BatchNorm -> in-place LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_dw(inp, oup, stride, leaky=0.1):
    """Depthwise 3x3 block followed by a pointwise 1x1 block, each with BN + LeakyReLU."""
    depthwise = [
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)


class SSH(nn.Module):
    """Context module that concatenates three branches of stacked 3x3 convolutions.

    The output has ``out_channel`` channels: half from the first branch and a
    quarter from each of the other two.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        assert out_channel % 4 == 0
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        branch_a = self.conv3X3(input)

        # The second and third branches share their first convolution.
        shared = self.conv5X5_1(input)
        branch_b = self.conv5X5_2(shared)

        branch_c = self.conv7x7_3(self.conv7X7_2(shared))

        merged = torch.cat([branch_a, branch_b, branch_c], dim=1)
        return F.relu(merged)


class FPN(nn.Module):
    """Top-down feature pyramid over three feature maps."""

    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        leaky = 0.1 if out_channels <= 64 else 0
        c1, c2, c3 = in_channels_list[0], in_channels_list[1], in_channels_list[2]
        self.output1 = conv_bn1X1(c1, out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(c2, out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(c3, out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        fine = self.output1(input[0])
        mid = self.output2(input[1])
        coarse = self.output3(input[2])

        # Upsample the coarser level to the next finer one, add, then smooth.
        coarse_up = F.interpolate(coarse, size=[mid.size(2), mid.size(3)], mode='nearest')
        mid = self.merge2(mid + coarse_up)

        mid_up = F.interpolate(mid, size=[fine.size(2), fine.size(3)], mode='nearest')
        fine = self.merge1(fine + mid_up)

        return [fine, mid, coarse]


class MobileNetV1(nn.Module):
    """Small MobileNetV1-style backbone split into three stages plus a classifier head."""

    def __init__(self):
        super().__init__()
        stage1_dw = [(8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1)]
        stage2_dw = [(64, 128, 2)] + [(128, 128, 1)] * 5
        stage3_dw = [(128, 256, 2), (256, 256, 1)]

        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),
            *[conv_dw(i, o, s) for i, o, s in stage1_dw],
        )
        self.stage2 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in stage2_dw])
        self.stage3 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in stage3_dw])
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.avg(x).view(-1, 256)
        return self.fc(pooled)


class ClassHead(nn.Module):
    """1x1 convolution producing 2 class scores per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        # NCHW -> NHWC so each location's anchors are contiguous, then flatten.
        scores = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return scores.view(scores.shape[0], -1, 2)


class BboxHead(nn.Module):
    """1x1 convolution producing 4 box offsets per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        offsets = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return offsets.view(offsets.shape[0], -1, 4)


class LandmarkHead(nn.Module):
    """1x1 convolution producing 10 landmark offsets per anchor."""

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        offsets = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return offsets.view(offsets.shape[0], -1, 10)


def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
    """Build one ClassHead per pyramid level."""
    return nn.ModuleList([ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])


def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
    """Build one BboxHead per pyramid level."""
    return nn.ModuleList([BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])


def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
    """Build one LandmarkHead per pyramid level."""
    return nn.ModuleList([LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])
def py_cpu_nms(dets, thresh):
    """Non-maximum suppression, delegated to ``torchvision.ops.nms``.

    Args:
        dets: array whose first four columns are (x1, y1, x2, y2) and whose
            fifth column is the score.
        thresh: IoU threshold forwarded to ``torchvision.ops.nms``.

    Return:
        list of the kept indices, as returned by iterating the result tensor.
    """
    keep = torchvision.ops.nms(
        boxes=torch.Tensor(dets[:, :4]),
        scores=torch.Tensor(dets[:, 4]),
        iou_threshold=thresh,
    )

    return list(keep)


def point_form(boxes):
    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
    representation for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat(
        (
            boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin
            boxes[:, :2] + boxes[:, 2:] / 2),  # xmax, ymax
        1)


def center_size(boxes):
    """ Convert point-form boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes, Shape: [N,4].
    Return:
        boxes: (tensor) Converted cx, cy, w, h form of boxes.
    """
    # torch.cat takes a *sequence* of tensors followed by the dim; the two
    # halves were previously passed as separate positional arguments, which
    # raises TypeError.
    return torch.cat(
        (
            (boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
            boxes[:, 2:] - boxes[:, :2]),  # w, h
        1)


def intersect(box_a, box_b):
    """ We resize both tensors to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    # Negative extents mean the boxes do not overlap on that axis.
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
    is simply the intersection over union of two boxes.  Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """
    return iou of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # The boolean mask zeroes pairs that do not overlap on both axes.
    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """
    return iof (intersection over the area of a) of a and b,
    numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    # The denominator is floored at 1.
    return area_i / np.maximum(area_a[:, np.newaxis], 1)
def encode(matched, priors, variances):
    """Turn matched ground-truth boxes into regression targets relative to priors.

    Args:
        matched: (tensor) ground-truth box per prior in point form
            (x1, y1, x2, y2), Shape: [num_priors, 4].
        priors: (tensor) prior boxes in center-offset form (cx, cy, w, h),
            Shape: [num_priors, 4].
        variances: (list[float]) ``variances[0]`` scales the center offsets,
            ``variances[1]`` scales the log size ratios.

    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """
    prior_centers = priors[:, :2]
    prior_sizes = priors[:, 2:]
    box_min = matched[:, :2]
    box_max = matched[:, 2:]

    # offset of the matched box center from the prior center, in units of
    # (variance * prior size)
    center_offset = (box_min + box_max) / 2 - prior_centers
    center_offset /= variances[0] * prior_sizes

    # log of the matched/prior size ratio, scaled by the second variance
    size_ratio = (box_max - box_min) / prior_sizes
    log_scale = torch.log(size_ratio) / variances[1]

    # target layout: [d_cx, d_cy, d_w, d_h]
    return torch.cat([center_offset, log_scale], 1)
def decode(loc, priors, variances):
    """Invert the train-time box encoding and return corner-form boxes.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions as (x1, y1, x2, y2) rows
    """
    prior_xy = priors[:, :2]
    prior_wh = priors[:, 2:]

    # predicted center and size in the priors' coordinate frame
    centers = prior_xy + loc[:, :2] * variances[0] * prior_wh
    sizes = prior_wh * torch.exp(loc[:, 2:] * variances[1])

    # center/size -> corner form
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)
def batched_decode(b_loc, priors, variances):
    """Batched inverse of the train-time box encoding; returns corner-form boxes.

    Args:
        b_loc (tensor): location predictions for loc layers,
            Shape: [num_batches,num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [1,num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions as (x1, y1, x2, y2) along the last axis
    """
    prior_xy = priors[:, :, :2]
    prior_wh = priors[:, :, 2:]

    # predicted center and size; priors broadcast over the batch axis
    centers = prior_xy + b_loc[:, :, :2] * variances[0] * prior_wh
    sizes = prior_wh * torch.exp(b_loc[:, :, 2:] * variances[1])

    # center/size -> corner form
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), dim=2)
def log_sum_exp(x):
    """Compute log(sum(exp(x))) along dimension 1, keeping that dimension.

    The global maximum of ``x`` is subtracted before exponentiating and added
    back afterwards, which keeps ``exp`` from overflowing.

    Args:
        x (Variable(tensor)): conf_preds from conf layers
    Return:
        tensor with dimension 1 reduced to size 1
    """
    shift = x.data.max()
    shifted_exp = torch.exp(x - shift)
    row_sums = torch.sum(shifted_exp, 1, keepdim=True)
    return torch.log(row_sums) + shift
""" keep = torch.Tensor(scores.size(0)).fill_(0).long() if boxes.numel() == 0: return keep x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] area = torch.mul(x2 - x1, y2 - y1) v, idx = scores.sort(0) # sort in ascending order # I = I[v >= 0.01] idx = idx[-top_k:] # indices of the top-k largest vals xx1 = boxes.new() yy1 = boxes.new() xx2 = boxes.new() yy2 = boxes.new() w = boxes.new() h = boxes.new() # keep = torch.Tensor() count = 0 while idx.numel() > 0: i = idx[-1] # index of current largest val # keep.append(i) keep[count] = i count += 1 if idx.size(0) == 1: break idx = idx[:-1] # remove kept element from view # load bboxes of next highest vals torch.index_select(x1, 0, idx, out=xx1) torch.index_select(y1, 0, idx, out=yy1) torch.index_select(x2, 0, idx, out=xx2) torch.index_select(y2, 0, idx, out=yy2) # store element-wise max with next highest score xx1 = torch.clamp(xx1, min=x1[i]) yy1 = torch.clamp(yy1, min=y1[i]) xx2 = torch.clamp(xx2, max=x2[i]) yy2 = torch.clamp(yy2, max=y2[i]) w.resize_as_(xx2) h.resize_as_(yy2) w = xx2 - xx1 h = yy2 - yy1 # check sizes of xx1 and xx2.. 
class YoloDetector:
    """Face detector wrapper around the YOLOv5-face ``Model``.

    Handles BGR->RGB conversion, resizing/letterboxing, the forward pass and
    conversion of the raw predictions into integer boxes and landmarks.
    """

    def __init__(
        self,
        config_name,
        min_face=10,
        target_size=None,
        device='cuda',
    ):
        """
        config_name: name of .yaml config with network configuration from models/ folder.
        min_face : minimal face size in pixels.
        target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
                    None for original resolution.
        device : device the preprocessed image batch is moved to.
        """
        self._class_path = Path(__file__).parent.absolute()
        self.target_size = target_size
        self.min_face = min_face
        self.detector = Model(cfg=config_name)
        self.device = device

    def _preprocess(self, imgs):
        """
        Preprocessing image before passing through the network. Resize and conversion to torch tensor.

        imgs: list of HWC images. Returns a float BCHW tensor scaled to 0.0 - 1.0 on ``self.device``.
        """
        pp_imgs = []
        for img in imgs:
            h0, w0 = img.shape[:2]  # orig hw
            if self.target_size:
                r = self.target_size / min(h0, w0)  # resize image to img_size
                if r < 1:  # only ever shrink, never enlarge
                    img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)

            imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max())  # check img_size
            img = letterbox(img, new_shape=imgsz)[0]
            pp_imgs.append(img)
        pp_imgs = np.array(pp_imgs)
        pp_imgs = pp_imgs.transpose(0, 3, 1, 2)  # BHWC to BCHW
        pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
        pp_imgs = pp_imgs.float()  # uint8 to fp16/32
        return pp_imgs / 255.0  # 0 - 255 to 0.0 - 1.0

    def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
        """
        Postprocessing of raw pytorch model output.
        Returns:
            bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
            points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
        """
        bboxes = [[] for _ in range(len(origimgs))]
        landmarks = [[] for _ in range(len(origimgs))]

        pred = non_max_suppression_face(pred, conf_thres, iou_thres)

        for image_id, origimg in enumerate(origimgs):
            img_shape = origimg.shape
            image_height, image_width = img_shape[:2]
            gn = torch.tensor(img_shape)[[1, 0, 1, 0]]  # normalization gain whwh
            gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]]  # normalization gain landmarks
            det = pred[image_id].cpu()
            # NOTE(review): the return values are not stored, so these calls presumably rescale
            # `det` in place, and the non-in-place .round() has no effect - confirm in utils.general.
            scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
            scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()

            for j in range(det.size()[0]):
                box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
                box = list(
                    map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
                )
                if box[3] - box[1] < self.min_face:  # drop faces shorter than min_face pixels
                    continue
                lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
                # even positions are x coordinates, odd positions are y coordinates
                lm = list(map(int, [v * image_width if k % 2 == 0 else v * image_height for k, v in enumerate(lm)]))
                lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
                bboxes[image_id].append(box)
                landmarks[image_id].append(lm)
        return bboxes, landmarks

    def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
        """
        Get bbox coordinates and keypoints of faces on original image.
        Params:
            imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
            conf_thres: confidence threshold for each prediction
            iou_thres: threshold for NMS (filter of intersecting bboxes)
        Returns:
            None when no face was kept, otherwise an array with one row per face and 15 columns:
            x1, y1, x2, y2, a copy of x1 (filler column), then 5 landmark (x, y) pairs.
            Rows of all input images are stacked in image order.
        """
        # Pass input images through face detector
        images = imgs if isinstance(imgs, list) else [imgs]
        images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
        origimgs = copy.deepcopy(images)

        images = self._preprocess(images)

        # Use inference_mode whenever this torch build provides it, otherwise fall back to no_grad.
        no_grad_ctx = getattr(torch, "inference_mode", torch.no_grad)
        with no_grad_ctx():
            pred = self.detector(images)[0]

        bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)

        # Flatten the per-image lists first: images with different face counts would otherwise
        # form a ragged nested list that np.array() cannot reshape.
        flat_boxes = [box for per_image in bboxes for box in per_image]
        if not flat_boxes:
            return None
        flat_points = [lm for per_image in points for lm in per_image]

        bboxes = np.array(flat_boxes).reshape(-1, 4)
        points = np.array(flat_points).reshape(-1, 10)
        padding = bboxes[:, 0].reshape(-1, 1)  # filler column between box and landmarks
        return np.concatenate((bboxes, padding, points), axis=1)

    def __call__(self, *args):
        """Alias for ``detect_faces`` so the detector object can be called directly."""
        return self.detect_faces(*args)
def autopad(k, p=None):
    """Return the padding to use for kernel size ``k``.

    An explicit ``p`` is returned untouched. When ``p`` is None the padding
    is half the kernel size (integer division): a single int for an int
    kernel, a list of per-axis values for an iterable kernel.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [side // 2 for side in k]
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        c1: input channels.
        c2: output channels.
        n: number of stacked Bottleneck modules.
        shortcut: forwarded to every Bottleneck.
        g: groups, forwarded to every Bottleneck.
        e: expansion; the hidden width is int(c2 * e).
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        stack = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*stack)

    def forward(self, x):
        """Run both branches on x, concatenate them on the channel axis and fuse with cv3."""
        main_path = self.m(self.cv1(x))
        side_path = self.cv2(x)
        return self.cv3(torch.cat((main_path, side_path), dim=1))
class Concat(nn.Module):
    """Module form of ``torch.cat``: joins a list of tensors along one fixed dimension."""

    def __init__(self, dimension=1):
        """dimension: axis along which ``forward`` concatenates."""
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """x: sequence of tensors; returns their concatenation along ``self.d``."""
        axis = self.d
        return torch.cat(x, axis)
class AutoShape(nn.Module):
    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    #
    # `stride` and `names` are not set here; forward() reads them from the instance, so the caller
    # must attach them (Model.autoshape() copies them over with copy_attr).
    img_size = 640  # inference size (pixels)
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        """Wrap `model`, switching it to eval mode."""
        super().__init__()
        self.model = model.eval()

    def autoshape(self):
        """No-op kept so that calling autoshape() on an already wrapped model is harmless."""
        print("autoShape already enabled, skipping... ")  # model already converted to model.autoshape()
        return self

    def _infer(self, x, augment, profile):
        """Call the wrapped model, forwarding the optional flags only when one is set.

        The Model defined in this package accepts just the input tensor, so passing the two flags
        unconditionally made every call fail with a TypeError. Models whose forward() does accept
        (x, augment, profile) still receive the flags whenever the caller enables one.
        """
        if augment or profile:
            return self.model(x, augment, profile)
        return self.model(x)

    def forward(self, imgs, size=640, augment=False, profile=False):
        """Run inference on a tensor, a single image or a list of images.

        A torch.Tensor is passed straight to the model and the raw model output is returned.
        Anything else is converted to numpy, letterboxed, batched, run through the model and NMS,
        rescaled to the original image sizes and returned as a Detections object.
        """
        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
        #   numpy:           = np.zeros((720,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            return self._infer(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        # list(imgs) copies the list so the caller's entries are not overwritten below
        n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1 = [], []  # image and inference shapes
        for i, im in enumerate(imgs):
            im = np.array(im)  # to numpy
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = size / max(s)  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im  # update
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0  # uint8 to fp16/32

        # Inference
        with torch.no_grad():
            y = self._infer(x, augment, profile)[0]  # forward
        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS

        # Post-process
        for i in range(n):
            scale_coords(shape1, y[i][:, :4], shape0[i])

        return Detections(imgs, y, self.names)
class CrossConv(nn.Module):
    """Cross Convolution Downsample: a (1, k) convolution followed by a (k, 1) convolution."""

    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        """
        c1: input channels.
        c2: output channels.
        k: kernel length used along one axis by each of the two convolutions.
        s: stride applied along the same axis as the kernel.
        g: groups of the second convolution.
        e: expansion; the hidden width is int(c2 * e).
        shortcut: add the input to the output (only effective when c1 == c2).
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, (1, k), (1, s))
        self.cv2 = Conv(hidden, c2, (k, 1), (s, 1), g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply both convolutions; add the residual input when ``self.add`` is set."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
class Detect(nn.Module):
    """Detection head producing box, objectness, 5-point landmark and class outputs per anchor.

    Per-anchor output layout (self.no = nc + 5 + 10 values):
        0:4   box (xy, wh)
        4     objectness
        5:15  five landmark (x, y) pairs
        15:   class scores
    """

    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        """
        nc: number of classes.
        anchors: one flat sequence of anchor (w, h) values per detection layer.
        ch: input channel count of each detection layer.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5 + 10  # number of outputs per anchor

        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer("anchors", a)  # shape(nl,na,2)
        self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

    def forward(self, x):
        """
        x: mutable sequence with one feature map per detection layer; its entries are replaced in place.

        Returns:
            export mode: x holding the raw convolution outputs.
            training mode: x holding the outputs reshaped to (bs, na, ny, nx, no).
            otherwise: tuple (decoded predictions concatenated over all layers, x).
        """
        z = []  # inference output
        if self.export:
            # export path: only the output convolutions, no reshaping or decoding
            for i in range(self.nl):
                x[i] = self.m[i](x[i])
            return x
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,na*no,ny,nx) to x(bs,na,ny,nx,no)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                # rebuild the cached cell grid whenever the feature-map size changed
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = torch.full_like(x[i], 0)
                # sigmoid on box, objectness and the value at index 15; landmarks stay raw.
                # Values at index 16 and above (present only when nc > 1) are left at zero.
                y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
                y[..., 5:15] = x[i][..., 5:15]

                y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh

                # each landmark: raw offset scaled by the anchor size, plus the grid cell origin in pixels
                y[..., 5:7] = (
                    y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
                )  # landmark x1 y1
                y[..., 7:9] = (
                    y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
                )  # landmark x2 y2
                y[..., 9:11] = (
                    y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
                )  # landmark x3 y3
                y[..., 11:13] = (
                    y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
                )  # landmark x4 y4
                y[..., 13:15] = (
                    y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
                )  # landmark x5 y5

                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        """Return a float tensor of shape (1, 1, ny, nx, 2) holding the (x, y) index of every cell."""
        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
forward m.anchors /= m.stride.view(-1, 1, 1) check_anchor_order(m) self.stride = m.stride self._initialize_biases() # only run once def forward(self, x): return self.forward_once(x) # single-scale inference, train def forward_once(self, x): y = [] # outputs for m in self.model: if m.f != -1: # if not from previous layer x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers x = m(x) # run y.append(x if m.i in self.save else None) # save output return x def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency # https://arxiv.org/abs/1708.02002 section 3.3 m = self.model[-1] # Detect() module for mi, s in zip(m.m, m.stride): # from b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) def _print_biases(self): m = self.model[-1] # Detect() module for mi in m.m: # from b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers print("Fusing layers... ") for m in self.model.modules(): if isinstance(m, Conv) and hasattr(m, "bn"): m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv delattr(m, "bn") # remove batchnorm m.forward = m.fuseforward # update forward elif type(m) is nn.Upsample: m.recompute_scale_factor = None # torch 1.11.0 compatibility return self def nms(self, mode=True): # add or remove NMS module present = isinstance(self.model[-1], NMS) # last layer is NMS if mode and not present: print("Adding NMS... 
def parse_model(d, ch):  # model_dict, input_channels(3)
    """Build the network described by a model dict.

    Args:
        d: model dict with the keys "anchors", "nc", "depth_multiple", "width_multiple",
            "backbone" and "head"; each backbone/head entry is [from, number, module, args].
        ch: list of channel counts, initially holding the input channels. The output channel
            count of every created layer is appended to it in place.

    Returns:
        (nn.Sequential of all layers, sorted list of layer indices whose output other layers read).
    """
    anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    # NOTE(review): this count omits the 10 landmark values that Detect adds per anchor; it is only
    # used below to exempt a matching channel count from rounding - confirm that is intended.
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        # SECURITY: module names and string arguments from the model dict are passed to eval().
        # Only use configuration files from a trusted source.
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except Exception:
                # best effort: a string that cannot be evaluated is kept as a plain string
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [
            Conv,
            Bottleneck,
            SPP,
            DWConv,
            MixConv2d,
            Focus,
            CrossConv,
            BottleneckCSP,
            C3,
            ShuffleV2Block,
            StemBlock,
        ]:
            c1, c2 = ch[f], args[0]

            # scale the width and round up to a multiple of 8
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3]:
                # these modules take the repeat count as an argument instead of being stacked
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            # ch[0] is the network input, so layer index x lives at ch[x + 1]
            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
        elif m is Detect:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
def check_anchor_order(m):
    """Make the anchor order of a YOLOv5 Detect() module agree with its stride order.

    The per-anchor area is taken as the product over the last axis of
    ``m.anchor_grid``. If the sign of (last area - first area) differs from the
    sign of (last stride - first stride), ``m.anchors`` and ``m.anchor_grid``
    are both reversed in place along their first axis.

    Args:
        m: object exposing ``anchors``, ``anchor_grid`` and ``stride`` tensors.
    """
    areas = m.anchor_grid.prod(-1).view(-1)  # anchor area
    area_trend = (areas[-1] - areas[0]).sign()
    stride_trend = (m.stride[-1] - m.stride[0]).sign()
    if area_trend == stride_trend:
        return  # already ordered consistently, nothing to do
    print("Reversing anchor order")
    m.anchors[:] = m.anchors.flip(0)
    m.anchor_grid[:] = m.anchor_grid.flip(0)
# One-off conversion script: reads the 'model' entry of the yolov5n-face checkpoint
# and writes that object's state_dict to weights/facelib/yolov5n-face.pth.
# All paths are relative, so the result depends on the current working directory.
import torch
import sys
# Put the yolov5face package directory first on the import path.
# NOTE(review): presumably required so that unpickling the checkpoint can resolve the model classes - confirm.
sys.path.insert(0,'./facelib/detection/yolov5face')
# NOTE(review): torch.load on a full checkpoint goes through pickle; only run this on a trusted file.
model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
# Persist only the parameters/buffers, not the pickled module object.
torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth')
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] (center + size) to [x1, y1, x2, y2] corners.

    The last axis must hold the four box values. Any number of leading axes is
    accepted, so a single box of shape (4,), the usual (n, 4) layout and batched
    (b, n, 4) inputs all work.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last axis has at least 4 entries.

    Returns:
        A new tensor/array of the same shape as ``x`` (the input is not modified),
        where xy1 is the top-left and xy2 the bottom-right corner.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
    """Perform Non-Maximum Suppression (NMS) on face-detection inference results.

    Each prediction row is read as: columns 0-3 box (center x, center y, width,
    height), column 4 object confidence, columns 5-14 ten landmark values and
    columns 15+ per-class scores (``nc = prediction.shape[2] - 15``).

    Args:
        prediction: tensor indexed as (image, candidate, 15 + nc).
        conf_thres: minimum confidence for a candidate to be kept.
        iou_thres: IoU threshold handed to ``torchvision.ops.nms``.
        classes: optional list of class indices to keep.
        agnostic: if True, boxes of different classes suppress each other.
        labels: optional per-image apriori labels (column 0 class, columns 1-4 box).

    Returns:
        A list with one tensor per image, each of shape (n, 16):
        (x1, y1, x2, y2, conf, 10 landmark values, cls). Images that were not
        processed (no candidates, or time limit hit) keep an empty (0, 16) tensor.
    """
    nc = prediction.shape[2] - 15  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 4096  # (pixels) maximum box width and height
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # one independent empty tensor per image, so results never alias each other
    output = [torch.zeros((0, 16), device=prediction.device) for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence filter (boolean indexing yields a copy)

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            label = labels[xi]
            v = torch.zeros((len(label), nc + 15), device=x.device)
            v[:, :4] = label[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(label)), label[:, 0].long() + 15] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 15:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx16 (xyxy, conf, landmarks, cls)
        if multi_label:
            i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
            # landmarks must be gathered with the same row index as box/conf
            x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 15:].max(1, keepdim=True)
            x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class (the class index lives in column 15; 5-14 are landmarks)
        if classes is not None:
            x = x[(x[:, 15:16] == torch.tensor(classes, device=x.device)).any(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Batched NMS
        c = x[:, 15:16] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

        if merge and (1 < n < 3e3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output
box v[:, 4] = 1.0 # conf v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls x = torch.cat((x, v), 0) # If none remain process next image if not x.shape[0]: continue # Compute conf x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4]) # Detections matrix nx6 (xyxy, conf, cls) if multi_label: i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) else: # best class only conf, j = x[:, 5:].max(1, keepdim=True) x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] # Filter by class if classes is not None: x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] # Check shape n = x.shape[0] # number of boxes if not n: # no boxes continue x = x[x[:, 4].argsort(descending=True)] # sort by confidence # Batched NMS c = x[:, 5:6] * (0 if agnostic else max_wh) # classes boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix weights = iou * scores[None] # box weights x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes if redundant: i = i[iou.sum(1) > 1] # require redundancy output[xi] = x[i] if (time.time() - t) > time_limit: print(f"WARNING: NMS time limit {time_limit}s exceeded") break # time limit exceeded return output def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): # Rescale coords (xyxy) from img1_shape to img0_shape if ratio_pad is None: # calculate from img0_shape gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 
def fuse_conv_and_bn(conv, bn):
    """Fuse a convolution and the batch-norm that follows it into a single Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    The fused layer copies the geometry of ``conv`` (kernel size, stride, padding,
    dilation, groups), always has a bias, and has ``requires_grad`` disabled. The
    batch-norm is folded in using its running statistics.

    Args:
        conv: the ``nn.Conv2d`` to fuse.
        bn: the ``nn.BatchNorm2d`` applied to the output of ``conv``.

    Returns:
        A new ``nn.Conv2d`` on the same device as ``conv.weight``.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,  # without this a dilated conv is fused into a non-dilated one
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # prepare filters: scale every output channel by gamma / sqrt(var + eps)
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

    # prepare spatial bias: scaled conv bias (zeros when conv has none) plus the batch-norm shift
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
class ConvBNReLU(nn.Module):
    """Bias-free 2D convolution followed by batch normalization and ReLU."""

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
        super().__init__()
        # positional order of nn.Conv2d: in, out, kernel_size, stride, padding
        self.conv = nn.Conv2d(in_chan, out_chan, ks, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)

    def forward(self, x):
        """Return relu(bn(conv(x)))."""
        return F.relu(self.bn(self.conv(x)))
class AttentionRefinementModule(nn.Module):
    """Re-weights the channels of a feature map with a gate computed from its global average."""

    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

    def forward(self, x):
        """Return the convolved features multiplied by their sigmoid channel gate."""
        feat = self.conv(x)
        # pooling with a kernel equal to the spatial size collapses each channel to one value
        pooled = F.avg_pool2d(feat, feat.shape[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return feat * gate
class BiSeNet(nn.Module):
    """BiSeNet parsing network: context path, feature fusion and three output heads."""

    def __init__(self, num_class):
        super().__init__()
        self.cp = ContextPath()
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, num_class)
        self.conv_out16 = BiSeNetOutput(128, 64, num_class)
        self.conv_out32 = BiSeNetOutput(128, 64, num_class)

    def forward(self, x, return_feat=False):
        """Run the network on ``x``.

        Returns:
            ``(out, out16, out32)``, each bilinearly resized to the spatial size of
            ``x``; when ``return_feat`` is true the tuple is extended with the three
            matching intermediate feature maps, resized the same way.
        """
        h, w = x.shape[2:]
        target = (h, w)

        def resize(tensor):
            return F.interpolate(tensor, target, mode='bilinear', align_corners=True)

        # the 1/8 backbone feature stands in for the spatial-path feature
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        fused = self.ffm(feat_res8, feat_cp8)

        # every head yields an (output, feature) pair
        heads = (self.conv_out(fused), self.conv_out16(feat_cp8), self.conv_out32(feat_cp16))

        outputs = tuple(resize(out) for out, _ in heads)
        if not return_feat:
            return outputs
        return outputs + tuple(resize(feat) for _, feat in heads)
class NormLayer(nn.Module):
    """Normalization Layers.

    Args:
        channels: input channels, for batch norm, instance norm and group norm.
        normalize_shape: input shape without batch size, for layer norm.
        norm_type: one of 'bn', 'in', 'gn', 'pixel', 'layer' or 'none'
            (compared case-insensitively).

    Raises:
        AssertionError: if ``norm_type`` is not one of the supported names.
    """

    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
        super(NormLayer, self).__init__()
        norm_type = norm_type.lower()
        self.norm_type = norm_type
        if norm_type == 'bn':
            self.norm = nn.BatchNorm2d(channels, affine=True)
        elif norm_type == 'in':
            self.norm = nn.InstanceNorm2d(channels, affine=False)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(32, channels, affine=True)
        elif norm_type == 'pixel':
            self.norm = lambda x: F.normalize(x, p=2, dim=1)
        elif norm_type == 'layer':
            self.norm = nn.LayerNorm(normalize_shape)
        elif norm_type == 'none':
            self.norm = lambda x: x * 1.0
        else:
            # raised explicitly (same exception type as the former assert) so the
            # check is not stripped when Python runs with -O
            raise AssertionError(f'Norm type {norm_type} not support.')

    def forward(self, x, ref=None):
        """Apply the configured normalization; ``ref`` is only forwarded for norm_type 'spade'."""
        if self.norm_type == 'spade':
            return self.norm(x, ref)
        else:
            return self.norm(x)
class ParseNet(nn.Module):
    """Encoder / residual body / decoder network producing a parsing mask and an image.

    Args:
        in_size: input resolution; sets the number of down-sampling steps.
        out_size: output resolution; sets the number of up-sampling steps.
        min_feat_size: smallest feature-map resolution (capped at ``in_size``).
        base_ch: channel count after the first convolution.
        parsing_ch: channel count of the mask output.
        res_depth: number of residual blocks in the body.
        relu_type: activation name passed to every residual block.
        norm_type: normalization name passed to every residual block.
        ch_range: ``(min_ch, max_ch)`` bounds every channel count is clipped to.
    """

    def __init__(self, in_size=128, out_size=128, min_feat_size=32, base_ch=64, parsing_ch=19, res_depth=10,
                 relu_type='LeakyReLU', norm_type='bn', ch_range=(32, 256)):
        super().__init__()
        self.res_depth = res_depth
        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
        min_ch, max_ch = ch_range

        def ch_clip(ch):
            # keep a channel count inside [min_ch, max_ch]
            return max(min_ch, min(ch, max_ch))

        min_feat_size = min(in_size, min_feat_size)
        down_steps = int(np.log2(in_size // min_feat_size))
        up_steps = int(np.log2(out_size // min_feat_size))

        # =============== define encoder-body-decoder ====================
        # Layers are created in encoder -> body -> decoder -> heads order, which also
        # fixes the order in which their parameters are initialised and registered.
        encoder = [ConvLayer(3, base_ch, 3, 1)]
        head_ch = base_ch
        for _ in range(down_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
            encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
            head_ch = head_ch * 2

        body = [ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args) for _ in range(res_depth)]

        decoder = []
        for _ in range(up_steps):
            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
            decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
            head_ch = head_ch // 2

        self.encoder = nn.Sequential(*encoder)
        self.body = nn.Sequential(*body)
        self.decoder = nn.Sequential(*decoder)
        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)

    def forward(self, x):
        """Return ``(out_mask, out_img)`` for the input batch ``x``."""
        feat = self.encoder(x)
        x = feat + self.body(feat)  # residual connection around the whole body
        x = self.decoder(x)
        out_img = self.out_img_conv(x)
        out_mask = self.out_mask_conv(x)
        return out_mask, out_img
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    """Build a stage of ``bnum`` BasicBlocks wrapped in ``nn.Sequential``.

    Only the first block changes the channel count and applies ``stride``; the
    remaining ``bnum - 1`` blocks map ``out_chan`` to ``out_chan`` with stride 1.
    """
    blocks = [BasicBlock(in_chan, out_chan, stride=stride)]
    blocks.extend(BasicBlock(out_chan, out_chan, stride=1) for _ in range(bnum - 1))
    return nn.Sequential(*blocks)
def get_largest_face(det_faces, h, w):
    """Pick the detection whose box, clipped to the image, covers the largest area.

    Args:
        det_faces: sequence of detections; entries 0-3 of each are read as
            left, top, right, bottom.
        h: image height used to clip the vertical coordinates.
        w: image width used to clip the horizontal coordinates.

    Returns:
        ``(det_faces[idx], idx)`` for the first detection with the maximal area.

    Raises:
        ValueError: if ``det_faces`` is empty.
    """

    def clamp(val, upper):
        # limit a coordinate to the range [0, upper]
        if val < 0:
            return 0
        return upper if val > upper else val

    def clipped_area(face):
        left, right = clamp(face[0], w), clamp(face[2], w)
        top, bottom = clamp(face[1], h), clamp(face[3], h)
        return (right - left) * (bottom - top)

    face_areas = [clipped_area(face) for face in det_faces]
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx
def set_upscale_factor(self, upscale_factor):
    """Replace the stored upscale factor.

    The value is stored exactly as given; unlike ``__init__`` it is not passed
    through ``int()``.
    """
    self.upscale_factor = upscale_factor
def get_face_landmarks_5(self,
                         only_keep_largest=False,
                         only_center_face=False,
                         resize=None,
                         blur_ratio=0.01,
                         eye_dist_threshold=None):
    """Detect faces in ``self.input_img`` and collect their landmarks.

    Detected boxes are appended to ``self.det_faces`` and the matching
    landmarks to ``self.all_landmarks_5``. When ``self.pad_blur`` is set, a
    reflect-padded and border-blurred copy of the input image is stored in
    ``self.pad_input_imgs`` for every kept face, and that face's landmarks
    are shifted in place by the padding offset.

    Args:
        only_keep_largest (bool): Keep only the face selected by
            ``get_largest_face``.
        only_center_face (bool): Keep only the face selected by
            ``get_center_face`` (ignored when ``only_keep_largest`` is set).
        resize (int | None): Target length of the shorter image side used for
            detection. The image is only ever enlarged, never shrunk.
        blur_ratio (float): Box-blur kernel size as a fraction of the crop
            square side (used only with ``pad_blur``).
        eye_dist_threshold (float | None): Faces whose eye distance is below
            this value are skipped.

    Returns:
        int: Number of faces kept (0 when nothing was detected).
    """
    if resize is None:
        scale = 1
        input_img = self.input_img
    else:
        h, w = self.input_img.shape[0:2]
        # always scale up: the factor is clamped to >= 1, so a plain linear
        # interpolation is the only one that can ever be selected
        scale = max(1, resize / min(h, w))
        h, w = int(h * scale), int(w * scale)
        input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LINEAR)

    with torch.no_grad():
        bboxes = self.face_det.detect_faces(input_img)

    if bboxes is None or bboxes.shape[0] == 0:
        return 0
    # map detections back to the coordinates of the un-resized input image
    bboxes = bboxes / scale

    for bbox in bboxes:
        # remove faces with too small eye distance: side faces or too small faces.
        # Landmarks are stored as (x, y) pairs starting at index 5, so the two
        # eyes are (bbox[5], bbox[6]) and (bbox[7], bbox[8]).
        eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
        if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
            continue

        last_index = 11 if self.template_3points else 15
        landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, last_index, 2)])
        self.all_landmarks_5.append(landmark)
        self.det_faces.append(bbox[0:5])

    if len(self.det_faces) == 0:
        return 0
    if only_keep_largest:
        h, w, _ = self.input_img.shape
        self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
    elif only_center_face:
        h, w, _ = self.input_img.shape
        self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

    # pad blurry images
    if self.pad_blur:
        self.pad_input_imgs = []
        img_h, img_w = self.input_img.shape[0:2]
        for landmarks in self.all_landmarks_5:
            # get landmarks
            eye_left = landmarks[0, :]
            eye_right = landmarks[1, :]
            eye_avg = (eye_left + eye_right) * 0.5
            mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
            eye_to_eye = eye_right - eye_left
            eye_to_mouth = mouth_avg - eye_avg

            # Get the oriented crop rectangle
            # half_w: half width of the oriented crop rectangle
            # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
            half_w = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
            # norm with the hypotenuse: get the direction
            half_w /= np.hypot(*half_w)
            rect_scale = 1.5
            half_w *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
            # half_h: half height of the oriented crop rectangle
            half_h = np.flipud(half_w) * [-1, 1]

            # center of the crop
            center = eye_avg + eye_to_mouth * 0.1
            # quad: (left_top, left_bottom, right_bottom, right_top)
            quad = np.stack([center - half_w - half_h, center - half_w + half_h,
                             center + half_w + half_h, center + half_w - half_h])
            # qsize: side length of the square
            qsize = np.hypot(*half_w) * 2
            border = max(int(np.rint(qsize * 0.1)), 3)

            # extent of the quad: (x_min, y_min, x_max, y_max)
            extent = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
                      int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
            # pad: (width_left, height_top, width_right, height_bottom).
            # x extents are compared with the image width and y extents with
            # the image height.
            pad = [
                max(-extent[0] + border, 1),
                max(-extent[1] + border, 1),
                max(extent[2] - img_w + border, 1),
                max(extent[3] - img_h + border, 1)
            ]

            if max(pad) > 1:
                # pad image
                pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                # modify landmark coords (in place, so self.all_landmarks_5 follows the padded image)
                landmarks[:, 0] += pad[0]
                landmarks[:, 1] += pad[1]
                # blur pad images
                pad_h, pad_w, _ = pad_img.shape
                rows, cols, _ = np.ogrid[:pad_h, :pad_w, :1]
                mask = np.maximum(1.0 - np.minimum(np.float32(cols) / pad[0],
                                                   np.float32(pad_w - 1 - cols) / pad[2]),
                                  1.0 - np.minimum(np.float32(rows) / pad[1],
                                                   np.float32(pad_h - 1 - rows) / pad[3]))
                blur = int(qsize * blur_ratio)
                if blur % 2 == 0:
                    blur += 1
                blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))

                pad_img = pad_img.astype('float32')
                pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                self.pad_input_imgs.append(pad_img)
            else:
                self.pad_input_imgs.append(np.copy(self.input_img))

    return len(self.all_landmarks_5)
def add_restored_face(self, face):
    """Queue a restored face for later pasting.

    When the input image was flagged as grayscale (``self.is_gray``), the
    face is first passed through ``bgr2gray`` so that it matches the input.
    """
    # convert img into grayscale only for grayscale inputs
    restored = bgr2gray(face) if self.is_gray else face
    self.restored_faces.append(restored)
# // 3 # # add border # if draw_box: # h, w = face_size # mask_border = np.ones((h, w, 3), dtype=np.float32) # border = int(1400/np.sqrt(total_face_area)) # mask_border[border:h-border, border:w-border,:] = 0 # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) # inv_mask_borders.append(inv_mask_border) # if not self.use_parse: # # compute the fusion edge based on the area of face # w_edge = int(total_face_area**0.5) // 20 # erosion_radius = w_edge * 2 # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) # blur_size = w_edge * 2 # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) # if len(upsample_img.shape) == 2: # upsample_img is gray image # upsample_img = upsample_img[:, :, None] # inv_soft_mask = inv_soft_mask[:, :, None] # always use square mask mask = np.ones(face_size, dtype=np.float32) inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) # remove the black borders inv_mask_erosion = cv2.erode( inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) pasted_face = inv_mask_erosion[:, :, None] * inv_restored total_face_area = np.sum(inv_mask_erosion) # // 3 # add border if draw_box: h, w = face_size mask_border = np.ones((h, w, 3), dtype=np.float32) border = int(1400/np.sqrt(total_face_area)) mask_border[border:h-border, border:w-border,:] = 0 inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) inv_mask_borders.append(inv_mask_border) # compute the fusion edge based on the area of face w_edge = int(total_face_area**0.5) // 20 erosion_radius = w_edge * 2 inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) blur_size = w_edge * 2 inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) if len(upsample_img.shape) == 2: # upsample_img is gray image upsample_img = upsample_img[:, :, None] inv_soft_mask = 
inv_soft_mask[:, :, None] # parse mask if self.use_parse: # inference face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) face_input = torch.unsqueeze(face_input, 0).to(self.device) with torch.no_grad(): out = self.face_parse(face_input)[0] out = out.argmax(dim=1).squeeze().cpu().numpy() parse_mask = np.zeros(out.shape) MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] for idx, color in enumerate(MASK_COLORMAP): parse_mask[out == idx] = color # blur the mask parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) # remove the black borders thres = 10 parse_mask[:thres, :] = 0 parse_mask[-thres:, :] = 0 parse_mask[:, :thres] = 0 parse_mask[:, -thres:] = 0 parse_mask = parse_mask / 255. parse_mask = cv2.resize(parse_mask, face_size) parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) inv_soft_parse_mask = parse_mask[:, :, None] # pasted_face = inv_restored fuse_mask = (inv_soft_parse_mask 256: # 16-bit image upsample_img = upsample_img.astype(np.uint16) else: upsample_img = upsample_img.astype(np.uint8) # draw bounding box if draw_box: # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) img_color = np.ones([*upsample_img.shape], dtype=np.float32) img_color[:,:,0] = 0 img_color[:,:,1] = 255 img_color[:,:,2] = 0 for inv_mask_border in inv_mask_borders: upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img if save_path is not None: path = os.path.splitext(save_path)[0] save_path = f'{path}.{self.save_ext}' imwrite(upsample_img, save_path) return upsample_img def clean_all(self): self.all_landmarks_5 = [] 
def align_crop_face_landmarks(img,
                              landmarks,
                              output_size,
                              transform_size=None,
                              enable_padding=True,
                              return_inverse_affine=False,
                              shrink_ratio=(1, 1)):
    """Align and crop face with landmarks.

    The output_size and transform_size are based on width. The height is
    adjusted based on shrink_ratio_h / shrink_ratio_w.

    Modified from:
    https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py

    Args:
        img (Numpy array): Input image.
        landmarks (Numpy array): 5 or 68 or 98 landmarks.
        output_size (int): Output face size.
        transform_size (int): Transform size. Usually four times
            output_size (the default when None).
        enable_padding (bool): Reflect-pad and blur the image border when the
            crop reaches outside the image. Default: True.
        return_inverse_affine (bool): Also compute the affine matrix that maps
            the cropped face back onto the original image. Default: False.
        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
            face for height and width (crop larger area). Default: (1, 1).

    Returns:
        tuple: ``(cropped_face, inverse_affine)``; ``inverse_affine`` is None
        unless ``return_inverse_affine`` is set.

    Raises:
        ValueError: If the number of landmarks is not 5, 68 or 98.
    """
    lm_type = 'retinaface_5'  # Options: dlib_5, retinaface_5

    if isinstance(shrink_ratio, (float, int)):
        shrink_ratio = (shrink_ratio, shrink_ratio)
    if transform_size is None:
        transform_size = output_size * 4

    # Parse landmarks
    lm = np.array(landmarks)
    if lm.shape[0] == 5 and lm_type == 'retinaface_5':
        eye_left = lm[0]
        eye_right = lm[1]
        mouth_avg = (lm[3] + lm[4]) * 0.5
    elif lm.shape[0] == 5 and lm_type == 'dlib_5':
        lm_eye_left = lm[2:4]
        lm_eye_right = lm[0:2]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = lm[4]
    elif lm.shape[0] == 68:
        lm_eye_left = lm[36:42]
        lm_eye_right = lm[42:48]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[48] + lm[54]) * 0.5
    elif lm.shape[0] == 98:
        lm_eye_left = lm[60:68]
        lm_eye_right = lm[68:76]
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        mouth_avg = (lm[76] + lm[82]) * 0.5
    else:
        # fail with a clear message instead of an unbound-variable error below
        raise ValueError(f'Unsupported number of landmarks: {lm.shape[0]}. Expected 5, 68 or 98.')

    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    eye_to_mouth = mouth_avg - eye_avg

    # Get the oriented crop rectangle
    # x: half width of the oriented crop rectangle
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
    # norm with the hypotenuse: get the direction
    x /= np.hypot(*x)  # get the hypotenuse of a right triangle
    rect_scale = 1
    x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
    # y: half height of the oriented crop rectangle
    y = np.flipud(x) * [-1, 1]

    x *= shrink_ratio[1]  # width
    y *= shrink_ratio[0]  # height

    # c: center
    c = eye_avg + eye_to_mouth * 0.1
    # quad: (left_top, left_bottom, right_bottom, right_top)
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    # qsize: side length of the square
    qsize = np.hypot(*x) * 2

    # keep the quad in original image coordinates for the inverse affine
    quad_ori = np.copy(quad)
    # Shrink, for large face
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        h, w = img.shape[0:2]
        rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
        img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
        quad /= shrink
        qsize /= shrink

    # Crop
    h, w = img.shape[0:2]
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
            min(crop[2] + border, w), min(crop[3] + border, h))
    if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
        img = img[crop[1]:crop[3], crop[0]:crop[2], :]
        quad -= crop[0:2]

    # Pad
    # pad: (width_left, height_top, width_right, height_bottom)
    h, w = img.shape[0:2]
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
           int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0),
           max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w = img.shape[0:2]
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                           np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1],
                                           np.float32(h - 1 - y) / pad[3]))
        blur = int(qsize * 0.02)
        if blur % 2 == 0:
            blur += 1
        blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))

        img = img.astype('float32')
        img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = np.clip(img, 0, 255)  # float32, [0, 255]
        quad += pad[:2]

    # Transform use cv2
    h_ratio = shrink_ratio[0] / shrink_ratio[1]
    dst_h, dst_w = int(transform_size * h_ratio), transform_size
    template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
    # use cv2.LMEDS method for the equivalence to skimage transform
    # ref: https://blog.csdn.net/yichxi/article/details/115827338
    affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
    cropped_face = cv2.warpAffine(
        img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132))  # gray

    if output_size < transform_size:
        cropped_face = cv2.resize(
            cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)

    if return_inverse_affine:
        dst_h, dst_w = int(output_size * h_ratio), output_size
        # the target quad must use dst_h for its bottom-left corner as well,
        # otherwise the mapping is skewed whenever h_ratio != 1
        template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
        # use cv2.LMEDS method for the equivalence to skimage transform
        # ref: https://blog.csdn.net/yichxi/article/details/115827338
        affine_matrix = cv2.estimateAffinePartial2D(quad_ori, template, method=cv2.LMEDS)[0]
        inverse_affine = cv2.invertAffineTransform(affine_matrix)
    else:
        inverse_affine = None
    return cropped_face, inverse_affine
def download_pretrained_models(file_ids, save_path_root):
    """Download Google Drive files into ``save_path_root`` using ``gdown``.

    Args:
        file_ids (dict): Maps the target file name to its Google Drive file id.
        save_path_root (str): Folder that receives the files; created if missing.

    Raises:
        ValueError: If the user answers anything but Y/N when asked whether an
            existing file should be overwritten.
    """
    import gdown

    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_id in file_ids.items():
        file_url = 'https://drive.google.com/uc?id='+file_id
        save_path = osp.abspath(osp.join(save_path_root, file_name))

        if not osp.exists(save_path):
            print(f'Downloading {file_name} to {save_path}')
        else:
            # the file is already there: ask before overwriting it
            answer = input(f'{file_name} already exist. Do you want to cover it? Y/N\n').lower()
            if answer == 'n':
                print(f'Skipping {file_name}')
                continue
            if answer != 'y':
                raise ValueError('Wrong input. Only accepts Y/N.')
            print(f'Covering {file_name} to {save_path}')

        gdown.download(file_url, save_path, quiet=False)
""" def _totensor(img, bgr2rgb, float32): if img.shape[2] == 3 and bgr2rgb: if img.dtype == 'float64': img = img.astype('float32') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = torch.from_numpy(img.transpose(2, 0, 1)) if float32: img = img.float() return img if isinstance(imgs, list): return [_totensor(img, bgr2rgb, float32) for img in imgs] else: return _totensor(imgs, bgr2rgb, float32) def load_file_from_url(url, model_dir=None, progress=True, file_name=None): """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py """ if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) parts = urlparse(url) filename = os.path.basename(parts.path) if file_name is not None: filename = file_name cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) return cached_file def scandir(dir_path, suffix=None, recursive=False, full_path=False): """Scan a directory to find the interested files. Args: dir_path (str): Path of the directory. suffix (str | tuple(str), optional): File suffix that we are interested in. Default: None. recursive (bool, optional): If set to True, recursively scan the directory. Default: False. full_path (bool, optional): If set to True, include the dir_path. Default: False. Returns: A generator for all the interested files with relative paths. 
""" if (suffix is not None) and not isinstance(suffix, (str, tuple)): raise TypeError('"suffix" must be a string or tuple of strings') root = dir_path def _scandir(dir_path, suffix, recursive): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): if full_path: return_path = entry.path else: return_path = osp.relpath(entry.path, root) if suffix is None: yield return_path elif return_path.endswith(suffix): yield return_path else: if recursive: yield from _scandir(entry.path, suffix=suffix, recursive=recursive) else: continue return _scandir(dir_path, suffix=suffix, recursive=recursive) def is_gray(img, threshold=10): img = Image.fromarray(img) if len(img.getbands()) == 1: return True img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) diff1 = (img1 - img2).var() diff2 = (img2 - img3).var() diff3 = (img3 - img1).var() diff_sum = (diff1 + diff2 + diff3) / 3.0 if diff_sum <= threshold: return True else: return False def rgb2gray(img, out_channel=3): r, g, b = img[:,:,0], img[:,:,1], img[:,:,2] gray = 0.2989 * r + 0.5870 * g + 0.1140 * b if out_channel == 3: gray = gray[:,:,np.newaxis].repeat(3, axis=2) return gray def bgr2gray(img, out_channel=3): b, g, r = img[:,:,0], img[:,:,1], img[:,:,2] gray = 0.2989 * r + 0.5870 * g + 0.1140 * b if out_channel == 3: gray = gray[:,:,np.newaxis].repeat(3, axis=2) return gray ================================================ FILE: modules/files_cache.py ================================================ import itertools import os from collections import UserDict from dataclasses import dataclass, field from typing import Callable, Dict, Iterator, List, Optional, Union from installer import log do_cache_folders = os.environ.get('SD_NO_CACHE', None) is None class Directory: # forward declaration ... 
def real_path(directory_path: str) -> Union[str, None]:
    """Return the absolute, user-expanded form of ``directory_path``.

    Any exception raised while expanding or normalising the path (for
    example when the argument is not path-like) is swallowed and ``None`` is
    returned instead.
    """
    resolved = None
    try:
        expanded = os.path.expanduser(directory_path)
        resolved = os.path.abspath(expanded)
    except Exception:
        resolved = None
    return resolved
Attemped to update Directory `{self.path}` with `{source.path}`' for dead_path in self.directories: if dead_path not in source.directories: delete_cached_directory(dead_path) self.directories[:] = source.directories self.files[:] = source.files object.__setattr__(self, 'mtime', source.mtime) @property def exists(self) -> bool: return self.path and os.path.exists(self.path) @property def is_directory(self) -> bool: return self.exists and os.path.isdir(self.path) @property def live_mtime(self) -> float: return os.path.getmtime(self.path) if self.is_directory else 0 @property def is_stale(self) -> bool: return not self.is_directory or self.mtime != self.live_mtime class DirectoryCache(UserDict, DirectoryCollection): def __delattr__(self, directory_path: str) -> None: directory: Directory = get_directory(directory_path, fetch=False) if directory: map(delete_cached_directory, directory.directories) directory.clear() del self.data[directory_path] def clean_directory(directory: Directory, /, recursive: RecursiveType=False) -> bool: if not directory.is_directory: is_clean = False delete_cached_directory(directory.path) else: is_clean = not directory.is_stale if not is_clean: directory.update(fetch_directory(directory.path)) else: for directory_path in directory.directories[:]: try: recurse = recursive and (not callable(recursive) or recursive(directory.path)) directory = get_directory(directory_path, fetch=recurse) if directory: if directory.is_directory: if recurse: is_clean = clean_directory(directory, recursive=recurse) and is_clean continue delete_cached_directory(directory_path) # If we had intended to fetch this directory, but didn't, that means it doesn't exist. Purge. 
def _walk(top, recurse:RecursiveType=True) -> Directory:
    """Yield a ``Directory`` for ``top`` and, optionally, for its subfolders.

    This is a re-implementation of the standard directory walk built on
    ``os.scandir``. Nothing is yielded when ``top`` cannot be scanned.
    ``recurse`` may be a bool or a callable; a callable is asked for each
    subfolder path and only folders it accepts are descended into.
    """
    try:
        entries = os.scandir(top)
    except OSError:
        return

    files = []
    subdirs = []
    with entries:
        for entry in entries:
            if not entry.is_dir():
                files.append(entry.path)
            elif entry.is_symlink() and not os.path.exists(entry.path):
                log.error(f'Files broken symlink: {entry.path}')
            else:
                subdirs.append(entry.path)

    yield Directory(top, files, subdirs)

    if not recurse:
        return
    for child_path in subdirs:
        if callable(recurse) and not recurse(child_path):
            continue
        yield from _walk(child_path, recurse=recurse)
def directory_mtime(directory_path:str, /, recursive:RecursiveType=True) -> float:
    """Return the newest cached mtime among the directories under a path.

    Args:
        directory_path: Directory to inspect.
        recursive: Passed through to ``get_directories``.

    Returns:
        The largest ``mtime`` of the directories returned by
        ``get_directories``, never less than 0. When no directory is returned
        the result is ``0.0``.
    """
    mtimes = [directory.mtime for directory in get_directories(directory_path, recursive=recursive)]
    # `max(0, *mtimes)` degenerates to `max(0)` for an empty list and raises
    # TypeError, so the empty case is handled through `default`
    return float(max(0, max(mtimes, default=0)))
def extension_filter(ext_filter: Optional[ExtensionList]=None, ext_blacklist: Optional[ExtensionList]=None) -> ExtensionFilter:
    """Build a predicate that accepts file paths by their (case-insensitive) ending.

    Args:
        ext_filter: Endings a path must have; an empty/None value accepts all.
        ext_blacklist: Endings a path must not have; empty/None rejects none.

    Returns:
        A function taking a file path and returning True when it passes both
        the whitelist and the blacklist.
    """
    allowed = tuple(ext.upper() for ext in ext_filter) if ext_filter else ()
    blocked = tuple(ext.upper() for ext in ext_blacklist) if ext_blacklist else ()

    def filter_functon(fp:str):
        # the path is only upper-cased when a list actually has to be checked
        if allowed and not fp.upper().endswith(allowed):
            return False
        return not (blocked and fp.upper().endswith(blocked))

    return filter_functon
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
#
# ptrs:            block of pointers handed to tl.load.
# offset_first:    offsets along the first dim, or None to skip the bounds
#                  check on that dim.
# offset_second:   offsets along the second dim, or None to skip the bounds
#                  check on that dim.
# boundary_first:  exclusive upper bound compared against offset_first.
# boundary_second: exclusive upper bound compared against offset_second.
#
# Returns the loaded block; positions rejected by the mask read as 0.0.
@triton.jit
def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
    if offset_first is not None and offset_second is not None:
        # Both dims are bounds-checked: combine the row and column masks.
        mask = (offset_first[:, None] < boundary_first) & \
               (offset_second[None, :] < boundary_second)
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_first is not None:
        # Only the first dim is bounds-checked.
        mask = offset_first[:, None] < boundary_first
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    elif offset_second is not None:
        # Only the second dim is bounds-checked.
        mask = offset_second[None, :] < boundary_second
        tensor = tl.load(ptrs, mask=mask, other=0.0)
    else:
        # No offsets supplied: plain unmasked load.
        tensor = tl.load(ptrs)
    return tensor
MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. if MASK_STEPS: k_offs_n = start_n + tl.arange(0, BLOCK_N) else: k_offs_n = None k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n if MASK_STEPS: # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. # last step might get wasted but that is okay. check if this masking works For # that case. 
if start_n + BLOCK_N == block_max and n_extra_tokens != 0: boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) # compute masks q_mask = OFFS_M[:, None] < actual_seqlen_q k_mask = (start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k p_mask = q_mask & k_mask # -- compute qk ---- qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) qk_scaled += bias if USE_ALIBI: # compute the global position of each token within the sequence global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) qk_scaled += alibi_block # get max scores so far m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) # scale and subtract max q_shifted = qk_scaled - m_ij[:, None] # Compute scaled QK and softmax probabilities if USE_EXP2: p = tl.math.exp2(q_shifted * RCP_LN2) else: p = tl.math.exp(q_shifted) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: rng_output = tl.rand(philox_seed, philox_ptrs) dropout_mask = rng_output > dropout_p # return scores with negative values for dropped vals sd_mask = tl.where(dropout_mask, p, -p) tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) # apply dropout mask in place p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that tl.store(sd_mask_ptrs, p, mask=p_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes # store the diff in maxes to adjust acc and li as we discover new maxes m_diff = m_i - m_ij if USE_EXP2: alpha = tl.math.exp2(m_diff * RCP_LN2) else: alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij acc += tl.dot(p.to(v.type.element_ty), v) k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: bias_ptrs += BLOCK_N * stride_bn if RETURN_SCORES: sd_mask_ptrs += BLOCK_N * stride_sn if ENABLE_DROPOUT: dropout_mask_ptrs += BLOCK_N * stride_sn philox_ptrs += BLOCK_N * stride_sn return acc, l_i, m_i def get_cdna_autotune_configs(): return [ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), # Fall-back config. 
def get_autotune_configs():
    """Select the Triton autotune configurations and cache keys for ``attn_fwd``.

    Returns:
        A ``(configs, keys)`` pair. With ``AUTOTUNE`` disabled this is a single
        fixed 64x64 configuration; otherwise the RDNA or CDNA candidate list
        is returned, depending on which device check succeeds first.

    Raises:
        ValueError: ``AUTOTUNE`` is enabled but neither ``is_rdna()`` nor
            ``is_cdna()`` reports true.
    """
    if not AUTOTUNE:
        fixed_config = triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
            num_stages=1,
            num_warps=4,
        )
        fixed_keys = [
            "IS_CAUSAL",
            "dropout_p",
            "MAX_SEQLENS_Q",
            "MAX_SEQLENS_K",
            "ACTUAL_BLOCK_DMODEL",
            "IS_VARLEN",
            "HQ",
            "HK",
        ]
        return [fixed_config], fixed_keys

    # Autotuning requested: RDNA is probed before CDNA.
    if is_rdna():
        return get_rdna_autotune_configs()
    if is_cdna():
        return get_cdna_autotune_configs()
    raise ValueError("Unknown Device Type")
stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 # compute offsets start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) # handle seqlen if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start elif IS_INFERENCE: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 seqlen_q = MAX_SEQLENS_Q seqlen_k = tl.load(Cache_seqlens + off_z) else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K # Now we compute whether we need to exit early due to causal masking. 
# This is because for seqlen_q > seqlen_k, M rows of the attn scores # are completely masked, resulting in 0s written to the output, and # inf written to LSE. We don't need to do any GEMMs in this case. # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = tl.cdiv(seqlen_k, BLOCK_N) if IS_CAUSAL: # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. # This captures the decrease in n_blocks if we have a rectangular attn matrix n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) # If we have no blocks after adjusting for seqlen deltas, this WG is part of # the blocks that are all 0. We exit early. if n_blocks <= 0: o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result tl.store(o_ptrs, acc, mask=o_ptrs_mask) # The tensor allocated for L is based on MAX_SEQLENS_Q as that is # statically known. 
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE) # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) # lse_mask = mask_m_offsets < causal_start_idx # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) l_ptrs_mask = offs_m < MAX_SEQLENS_Q tl.store(l_ptrs, l, mask=l_ptrs_mask) return # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK if GROUP_SIZE != 1: off_h_k = off_h_q // GROUP_SIZE else: off_h_k = off_h_q n_extra_tokens = 0 # print("n_extra_tokens:", n_extra_tokens) # print("seqlen_k:", seqlen_k) # print("BLOCK_N:", BLOCK_N) # return if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn else: bias_ptrs = None if USE_ALIBI: a_offset = off_z * stride_az + off_h_q * stride_ah alibi_slope = tl.load(alibi_slopes + a_offset) else: alibi_slope = None if RETURN_SCORES: sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm sd_mask_ptrs = sd_mask_offset + 
offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: sd_mask_ptrs = None if ENABLE_DROPOUT: dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: dropout_mask_ptrs = None philox_ptrs = 0 # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) if IS_CAUSAL: # There are always at least BLOCK_M // BLOCK_N masked blocks. # Additionally there might be one more due to dissimilar seqlens. masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. # In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N # Compute for full blocks. Here we set causal to false regardless of its actual # value because there is no masking. Similarly we do not need padding. 
if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ block_min, block_max, 0, 0, 0, alibi_slope, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. if masked_blocks > 0: if IS_CAUSAL: offs_n_causal = offs_n + (seqlen_q - seqlen_k) else: offs_n_causal = 0 k_ptrs += n_full_blocks * BLOCK_N * stride_kn v_ptrs += n_full_blocks * BLOCK_N * stride_vk if USE_BIAS: bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_SCORES: sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn if ENABLE_DROPOUT: dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn philox_ptrs += n_full_blocks * BLOCK_N * stride_sn acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) # epilogue # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
l_recip = 1 / l_i[:, None] acc = acc * l_recip if ENABLE_DROPOUT: dropout_scale = 1 / (1 - dropout_p) acc = acc * dropout_scale # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, # then we have one block with a row of all NaNs which come from computing # softmax over a row of all -infs (-inf - inf = NaN). We check for that here # and store 0s where there are NaNs as these rows should've been zeroed out. end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE(Log Sum Exponents), the log of the normalization constant l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 LN2: tl.constexpr = 0.6931471824645996 # compute log-sum-exp in base 2 units mi_base2 = m_i * RCP_LN2 softmax_lse = mi_base2 + tl.math.log2(l_i) # convert back to natural units softmax_lse *= LN2 else: softmax_lse = m_i + tl.math.log(l_i) if IS_CAUSAL: # zero out nans caused by -infs when doing causal lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. # This is only true for the last M block. 
def attention_prefill_forward_triton_impl(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool,
    bias: Optional[torch.Tensor],
    layout: Literal["bshd", "bhsd", "thd"],
    # varlen
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    max_seqlens_q: int,
    max_seqlens_k: int,
    # inference
    cache_seqlens: Optional[Union[(int, torch.Tensor)]],
    cache_batch_idx: Optional[torch.Tensor],
    # dropout
    dropout_p: float,
    philox_seed: Optional[int],
    philox_offset: Optional[int],
    # misc
    return_softmax: bool,
    use_exp2: bool,
):
    """Prepare launch parameters and run the ``attn_fwd`` Triton kernel.

    The attention result is written into ``o`` in place; this function
    itself returns nothing.

    Args:
        q, k, v: Query / key / value tensors arranged according to ``layout``.
        o: Output tensor the kernel stores into.
        sm_scale: Scale applied to the QK product inside the kernel.
        alibi_slopes: Optional ALiBi slopes; its two strides are forwarded.
        causal: Enables the kernel's causal masking.
        bias: Optional additive bias; must hold fewer than 2**31 elements.
        layout: ``"thd"`` selects the variable-length path.
        cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths (varlen path).
        max_seqlens_q, max_seqlens_k: Maximum sequence lengths; they size the
            launch grid and the auxiliary buffers.
        cache_seqlens: When not ``None`` the inference path is selected,
            which requires the ``"bshd"`` layout.
        cache_batch_idx: Forwarded to the kernel unchanged.
        dropout_p: Dropout probability; dropout is enabled when ``> 0``.
        philox_seed, philox_offset: RNG seed / offset forwarded for dropout.
        return_softmax: Forwarded as the kernel's ``RETURN_SCORES`` flag.
        use_exp2: Forwarded as the kernel's ``USE_EXP2`` flag.

    Raises:
        AssertionError: Inference requested with a non-``"bshd"`` layout, or
            ``bias`` is too large for 32-bit pointer arithmetic.
    """
    # check flags
    is_varlen = layout == "thd"
    use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0))
    is_inference = cache_seqlens is not None
    use_bias = bias is not None
    use_dropout = dropout_p > 0.0
    if is_inference:
        assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout"

    # NOTE: a large bias tensor leads to overflow during pointer arithmetic
    if use_bias:
        assert (bias.numel() < 2**31)

    batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k)
    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)

    # Round the head size up to the next power of two, with a floor of 16.
    # Only the tile inside the kernel is padded - there is no padding in
    # memory for any dims.
    padded_d_model = max(1 << (head_size - 1).bit_length(), 16)

    def grid(META):
        # One program per BLOCK_M query rows, per query head, per batch entry.
        return (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch)

    # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
    # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
    # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
    # only. This return holds no useful output aside from debugging.
    if use_dropout or return_softmax:
        sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32)
        dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32)
        # sd_mask was just created as a 4-D tensor, so this yields four strides.
        scores_strides = sd_mask.stride()
    else:
        sd_mask = None
        dropout_mask = None
        scores_strides = (0, 0, 0, 0)

    # Stores LSE: the log of the normalization constant, i.e. of the sum of
    # exponentiated scores (unnormalized probabilities).
    if is_varlen:
        softmax_lse = torch.zeros((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
        stride_lse_m, stride_lse_h = softmax_lse.stride()
        stride_lse_z = 0
    else:
        softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
        stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()

    if use_bias:
        # Explicit indices keep the implicit requirement that bias is 4-D.
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
    else:
        bias_strides = (0, 0, 0, 0)

    attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
                   sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
                   *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
                   dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
                   HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
                   MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference,
                   BLOCK_DMODEL=padded_d_model, USE_BIAS=use_bias,
                   USE_ALIBI=use_alibi, ENABLE_DROPOUT=use_dropout, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax)
def fwd(q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        dropout_p: float,
        softmax_scale: float,
        causal: bool
    ):
    """Run the Triton prefill attention forward pass, writing the result into ``out``.

    Args:
        q, k, v: Attention inputs. They are handled with the ``"bshd"`` layout
            and dimension 1 of ``q`` / ``k`` is taken as the sequence length.
        out: Tensor the attention output is written into.
        dropout_p: Dropout probability; values ``> 0`` enable dropout.
        softmax_scale: Scale stored as the metadata's ``sm_scale``.
        causal: Enables causal masking.

    Returns:
        None. The result is delivered through ``out``.
    """
    # Setup metadata
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k.shape[1]
    metadata.layout = "bshd"

    if causal:
        metadata.need_causal(True)

    if dropout_p > 0.0:
        metadata.need_dropout(dropout_p)

    # check arguments
    metadata.check_args(q, k, v, out)

    # call implementation
    attention_prefill_forward_triton_impl(
        q,
        k,
        v,
        out,
        metadata.sm_scale,
        metadata.alibi_slopes,
        metadata.causal,
        None,
        metadata.layout,
        metadata.cu_seqlens_q,
        metadata.cu_seqlens_k,
        metadata.max_seqlens_q,
        metadata.max_seqlens_k,
        metadata.cache_seqlens,
        metadata.cache_batch_idx,
        metadata.dropout_p,
        metadata.philox_seed,
        metadata.philox_offset,
        # Forward the metadata flag rather than a hard-coded False: the
        # kernel's dropout branch stores through the score pointers, which are
        # only set up when scores are requested. need_dropout() turns the flag
        # on, and it keeps its False default when dropout is disabled, so the
        # no-dropout path is unchanged.
        metadata.return_scores,
        metadata.use_exp2)
layout: Optional[Literal["bshd", "bhsd", "thd"]] = None cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None cache_batch_idx = None packing: Optional[bool] = None return_scores: bool = False dropout_p: float = 0.0 philox_seed: Optional[int] = None philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. use_exp2: bool = False rotary_sin: Optional[torch.Tensor] = None rotary_cos: Optional[torch.Tensor] = None rotary_interleaved: bool = False rotary_conjunction: bool = False def __repr__(self) -> str: return (f"MetaData(\n" f" sm_scale={self.sm_scale},\n" f" cu_seqlens_q={self.cu_seqlens_q},\n" f" cu_seqlens_k={self.cu_seqlens_k},\n" f" max_seqlens_q={self.max_seqlens_q},\n" f" max_seqlens_k={self.max_seqlens_k},\n" f" bias={self.bias},\n" f" alibi_slopes={self.alibi_slopes},\n" f" causal={self.causal},\n" f" num_contexts={self.num_contexts},\n" f" varlen={self.varlen},\n" f" layout={self.layout},\n" f" cache_seqlens={self.cache_seqlens},\n" f" cache_batch_idx={self.cache_batch_idx},\n" f" dropout_p={self.dropout_p},\n" f" return_scores={self.return_scores}\n" f")") def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): self.varlen = True self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlens_q = max_seqlen_q self.max_seqlens_k = max_seqlen_k # Without "varlen", there should still be one sequence. 
assert len(cu_seqlens_q) >= 2 assert len(cu_seqlens_q) == len(cu_seqlens_k) def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): assert bias.is_cuda assert bias.dim() == 4 assert bias.shape[0] == 1 assert bias.shape[2:] == (seqlen_q, seqlen_k) self.bias = bias def need_alibi(self, alibi_slopes, batch, nheads): assert alibi_slopes.is_cuda assert alibi_slopes.dim() == 2 assert alibi_slopes.shape[0] == batch assert alibi_slopes.shape[1] == nheads self.alibi_slopes = alibi_slopes def need_causal(self, causal): self.causal = causal def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): self.rotary_sin = sin self.rotary_cos = cos self.rotary_interleaved = rotary_interleaved self.rotary_conjunction = rotary_conjunction def need_dropout(self, dropout_p, return_scores = True): if dropout_p > 0.0: self.dropout_p = dropout_p self.return_scores = return_scores self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) if self.varlen: assert q.dim() == 3 assert self.cu_seqlens_q is not None assert self.cu_seqlens_k is not None assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) assert self.bias is None # assert not self.return_scores else: assert q.dim() == 4 assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None assert k.shape == v.shape assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] assert q.dtype == k.dtype and q.dtype == v.dtype assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen # ------------------------------- # Input Helper # ------------------------------- def random_seqlens_composition(SEQ_LEN, BATCH): # generate a random 
def generate_varlen_tensor(
    total_seqlen: int,
    num_heads: int,
    head_size: int,
    batch_size: Optional[int] = None,
    equal_seqlens: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
    DEBUG_INPUT: bool = False
):
    """Create a packed variable-length tensor together with its sequence boundaries.

    Args:
        total_seqlen: Total number of tokens across all sequences.
        num_heads: Size of the second dimension of the result.
        head_size: Size of the last dimension of the result.
        batch_size: Number of sequences. When ``None`` one is drawn at random
            from the powers of two up to 64 that do not exceed ``total_seqlen``.
        equal_seqlens: Split the tokens evenly (the last sequence absorbs the
            remainder) instead of drawing a random composition.
        device: Device for every tensor created here.
        dtype: Element type of the data tensor.
        DEBUG_INPUT: Fill each sequence with its own position index
            (0, 1, 2, ...) instead of random normal values.

    Returns:
        ``(x, cu_seqlens, max_seqlen)``: the ``(total_seqlen, num_heads,
        head_size)`` tensor with gradients enabled, the int32 cumulative
        sequence lengths starting at 0, and the longest sequence length as a
        Python int.
    """
    # Pick a batch size when the caller did not supply one.
    if batch_size is None:
        candidates = [size for size in (1, 2, 4, 8, 16, 32, 64) if size <= total_seqlen]
        batch_size = random.choice(candidates)

    # Per-sequence lengths.
    if not equal_seqlens:
        seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device)
    else:
        per_sequence, leftover = divmod(total_seqlen, batch_size)
        seqlens = torch.full((batch_size,), per_sequence, dtype=torch.int32, device=device)
        seqlens[-1] += leftover

    # Cumulative boundaries with a leading zero, plus the longest sequence.
    leading_zero = torch.tensor([0], dtype=torch.int32, device=device)
    running_total = seqlens.cumsum(dim=0)
    cu_seqlens = torch.cat([leading_zero, running_total]).to(torch.int32).to(device=device)
    max_seqlen = seqlens.max().to(torch.int32).item()

    # Data tensor: deterministic ramps for debugging, random normal otherwise.
    if not DEBUG_INPUT:
        x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device)
    else:
        x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device)
        boundaries = cu_seqlens.tolist()
        for seq_idx in range(batch_size):
            start, end = boundaries[seq_idx], boundaries[seq_idx + 1]
            length = end - start
            ramp = torch.arange(length, dtype=dtype, device=device).view(length, 1, 1)
            x[start:end] = ramp.expand(length, num_heads, head_size)

    x.requires_grad_()
    return x, cu_seqlens, max_seqlen
def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False):
    """Create a ``(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD)`` tensor that requires grad.

    With ``DEBUG_INPUT`` every element holds its index along the sequence axis;
    otherwise the tensor is filled with standard-normal noise.
    """
    tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD)
    if not DEBUG_INPUT:
        x = torch.randn(tensor_shape, dtype=dtype, device=device)
    else:
        positions = torch.arange(SEQ_LEN, dtype=dtype, device=device)
        # broadcast the position ramp over batch, head and feature axes, then materialize it
        x = positions.view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous()
    x.requires_grad_()
    return x
@triton.jit
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
    # Compute the ALiBi bias for one tile: -alibi_slope * |relative position|, where the
    # relative position is measured against a diagonal anchored at the bottom right of
    # the (seqlen_q x seqlen_k) attention matrix. When `transpose` is set the transposed
    # tile is returned.
    #
    # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
    # for causal mask we want something like this where (1 is kept and 0 is masked)
    # seqlen_q = 2 and seqlen_k = 5
    #   1 1 1 1 0
    #   1 1 1 1 1
    # seqlen_q = 5 and seqlen_k = 2
    #   0 0
    #   0 0
    #   0 0
    #   1 0
    #   1 1
    # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
    # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
    # (only the first two rows of offs_m are written out below)
    # 1. offs_m[:,None]                                         = [[0],
    #                                                              [1],
    # 2. offs_m[:,None] + seqlen_k                              = [[5],
    #                                                              [6],
    # 3. offs_m[:,None] + seqlen_k - seqlen_q                   = [[3],
    #                                                              [4],
    # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:]  = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
    #                                                              [4],                        [ 4, 3, 2, 1, 0]]
    # 5. -1 * alibi_slope * tl.abs(relative_pos_block)          = [[ -3, -2, -1, 0,-1],
    #                                                              [ -4, -3, -2, -1, 0]],
    relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
    alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
    if transpose:
        return alibi_block.T
    else:
        return alibi_block
def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]):
    """Return the strides of ``x`` reordered as (batch, head, sequence, dim).

    Args:
        x: tensor laid out according to ``layout``; 4-D for "bhsd"/"bshd", 3-D for "thd".
        layout: one of "bshd", "bhsd" or "thd".

    Returns:
        A 4-tuple of strides. For the packed "thd" layout there is no batch
        axis, so the batch stride is reported as 0.

    Raises:
        AssertionError: if ``layout`` is not one of the supported values.
    """
    if layout == 'thd':
        strides = (0, x.stride(1), x.stride(0), x.stride(2))
    elif layout == 'bhsd':
        strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3))
    elif layout == 'bshd':
        strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
    else:
        # raised explicitly (same exception type as the former `assert False`) so the check
        # survives `python -O` instead of falling through to an unbound `strides`
        raise AssertionError('Got unsupported layout.')
    return strides
def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
    """Reference (pure PyTorch) ALiBi bias tensor.

    Args:
        alibi_slopes: tensor of per-head slopes; two trailing singleton axes are
            appended to it before broadcasting against the position grid.
        seqlen_q: number of query positions.
        seqlen_k: number of key positions.

    Returns:
        ``-slope * |q + seqlen_k - seqlen_q - k|`` broadcast to
        ``alibi_slopes.shape + (seqlen_q, seqlen_k)``; the zero-penalty diagonal
        is anchored at the bottom right of the attention matrix.
    """
    # build the index grids on the same device as the slopes instead of a hard-coded "cuda"
    device = alibi_slopes.device
    q_idx = torch.arange(seqlen_q, dtype=torch.int32, device=device).unsqueeze(-1)  # (N_CTX_Q, 1)
    k_idx = torch.arange(seqlen_k, dtype=torch.int32, device=device).unsqueeze(0)  # (1, N_CTX_K)
    relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx)  # (N_CTX_Q, N_CTX_K)
    return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos  # (Z, H, N_CTX_Q, N_CTX_K)
def get(endpoint: str, dct: dict = None):
    """Issue a GET request against the configured server and return the decoded JSON body.

    A non-200 response is not raised; it is returned as a dict with the
    'error' (status code), 'reason' and 'url' of the failed request.
    """
    response = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if response.status_code == 200:
        return response.json()
    return { 'error': response.status_code, 'reason': response.reason, 'url': response.url }
def encode(f):
    """Load an image file and return it as a base64-encoded JPEG string.

    Args:
        f: path of the image file to read.

    Returns:
        The JPEG bytes of the image, base64-encoded and decoded to ``str``.

    The process is terminated with exit code 1 when the file does not exist.
    """
    if not os.path.exists(f):
        log.error(f'file not found: {f}')
        os._exit(1)
    # context manager guarantees the file handle is released even if conversion/saving fails
    with Image.open(f) as image:
        # JPEG cannot store alpha or palette data (RGBA, LA, P, ...), so normalize to RGB
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        with io.BytesIO() as stream:
            rgb.save(stream, 'JPEG')
            values = stream.getvalue()
    encoded = base64.b64encode(values).decode()
    return encoded
if __name__ == "__main__":
    # CLI: read a tensor named 'frames' from a safetensors file and export it as
    # individual images and/or encode it to a video using cv2 or torchvision
    parser = argparse.ArgumentParser(description = 'framepack-cli')
    parser.add_argument('--input', required=True, help='input safetensors')
    parser.add_argument('--cv2', required=False, help='encode video file using cv2')
    parser.add_argument('--tv', required=False, help='encode video file using torchvision')
    parser.add_argument('--codec', default='libx264', help='specify video codec')
    parser.add_argument('--export', required=False, help='export frames as images to folder')
    # explicit type: without it a value given on the command line stays a string and
    # the video writers below receive a non-numeric fps
    parser.add_argument('--fps', type=int, default=30, help='frames-per-second')
    args = parser.parse_args()

    log.info(f'framepack-cli: {args}')
    log.info(f'torch={torch.__version__} torchvision={torchvision.__version__}')
    with safe_open(args.input, framework="pt", device="cpu") as f:
        frames = f.get_tensor('frames')
        metadata = f.metadata()
    n, h, w, _c = frames.shape  # frames are unpacked as (count, height, width, channels)
    log.info(f'file: metadata={metadata}')
    log.info(f'tensor: frames={n} shape={frames.shape} dtype={frames.dtype} device={frames.device}')
    fn = os.path.splitext(os.path.basename(args.input))[0]
    if not (args.export or args.cv2 or args.tv):
        log.warning('no output selected: use --export, --cv2 or --tv')
    if args.export:
        log.info(f'export: folder="{args.export}" prefix="{fn}" frames={n} width={w} height={h}')
        os.makedirs(args.export, exist_ok=True)
        for i in trange(n):
            image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(args.export, f'{fn}-{i:05d}.jpg'), image)
    if args.cv2:
        log.info(f'encode: file={args.cv2} frames={n} width={w} height={h} fps={args.fps} method=cv2')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(args.cv2, fourcc, args.fps, (w, h))
        try:
            for i in trange(n):
                image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)
                video.write(image)
        finally:
            # always finalize the container, even if a frame conversion fails
            video.release()
    if args.tv:
        log.info(f'encode: file={args.tv} frames={n} width={w} height={h} fps={args.fps} method=tv ')
        torchvision.io.write_video(args.tv, video_array=frames, fps=args.fps, video_codec=args.codec)
Field(default=None, title="Initial image", description="Base64 encoded initial image") end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image") start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image") end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image") vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model") system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model") optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model") section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section") negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model") styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model") seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model") resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image") duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds") latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window") steps: Optional[int] = Field(default=25, title="Video steps", description="Number of steps for the video generation") cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model") cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model") cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG 
def framepack_post(request: ReqFramepack):
    """Handle a FramePack generation request.

    Decodes the optional base64 init/end images to numpy arrays, forwards all
    request fields to ``run_framepack`` and drains the generator it returns,
    keeping the last filename and message it reported.

    Args:
        request: parsed request body.

    Returns:
        ResFramepack with the task id, last reported filename and last reported message.

    Raises:
        HTTPException: status 500 when the init or end image cannot be decoded.
    """
    import numpy as np
    from modules.api import helpers
    from framepack_wrappers import run_framepack
    task_id = shared.state.get_id()

    try:
        # missing or empty payload means "no image"
        init_image = np.array(helpers.decode_base64_to_image(request.init_image)) if request.init_image else None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode init image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        end_image = np.array(helpers.decode_base64_to_image(request.end_image)) if request.end_image else None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode end image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # drop the raw base64 payloads so they are not included when the request is logged below
    del request.init_image
    del request.end_image

    # compare against None explicitly: the truth value of a numpy array is ambiguous and raises
    init_shape = init_image.shape if init_image is not None else None
    end_shape = end_image.shape if end_image is not None else None
    shared.log.trace(f"API FramePack: id={task_id} init={init_shape} end={end_shape} {request}")

    generator = run_framepack(
        _ui_state=None,
        task_id=f'task({task_id})',
        variant=request.variant,
        init_image=init_image,
        end_image=end_image,
        start_weight=request.start_weight,
        end_weight=request.end_weight,
        vision_weight=request.vision_weight,
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        optimized_prompt=request.optimized_prompt,
        section_prompt=request.section_prompt,
        negative_prompt=request.negative_prompt,
        styles=request.styles,
        seed=request.seed,
        resolution=request.resolution,
        duration=request.duration,
        latent_ws=request.latent_ws,
        steps=request.steps,
        cfg_scale=request.cfg_scale,
        cfg_distilled=request.cfg_distilled,
        cfg_rescale=request.cfg_rescale,
        shift=request.shift,
        use_teacache=request.use_teacache,
        use_cfgzero=request.use_cfgzero,
        use_preview=False,
        mp4_fps=request.mp4_fps,
        mp4_codec=request.mp4_codec,
        mp4_sf=request.mp4_sf,
        mp4_video=request.mp4_video,
        mp4_frames=request.mp4_frames,
        mp4_opt=request.mp4_opt,
        mp4_ext=request.mp4_ext,
        mp4_interpolate=request.mp4_interpolate,
        attention=request.attention,
        vae_type=request.vae_type,
        vlm_enhance=request.vlm_enhance,
        vlm_model=request.vlm_model,
        vlm_system_prompt=request.vlm_system_prompt,
    )
    response = ResFramepack(id=task_id, filename='', message='')
    # only 3-tuples are inspected: a string in slot 0 is taken as the filename,
    # a string in slot 2 as the status message; later items overwrite earlier ones
    for message in generator:
        if isinstance(message, tuple) and len(message) == 3:
            if isinstance(message[0], str):
                response.filename = message[0]
            if isinstance(message[2], str):
                response.message = message[2]
    return response
def set_progress_bar_config():
    """Replace ``uni_pc_fm.sample_unipc`` with a variant that always disables the sampler progress bar."""
    from modules.framepack.pipeline import uni_pc_fm

    def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): # pylint: disable=unused-argument
        # `disable` is accepted for signature compatibility but ignored: the bar is always off
        solver = uni_pc_fm.FlowMatchUniPC(model, extra_args=extra_args, variant=variant)
        return solver.sample(noise, sigmas=sigmas, callback=callback, disable_pbar=True)

    uni_pc_fm.sample_unipc = sample_unipc
def rename(src:str, dst:str):
    """Rename ``src`` to ``dst``, falling back to ``shutil.move`` for a cross-device rename.

    Any other ``OSError`` raised by ``os.rename`` is propagated unchanged.
    """
    import errno
    try:
        os.rename(src, dst)
    except OSError as err:
        if err.errno != errno.EXDEV:
            raise
        # EXDEV: source and destination are on different devices
        shutil.move(src, dst)
def git_clone(git_repo:str, git_dir:str, tmp_dir:str):
    """Clone ``git_repo`` into ``git_dir`` via a temporary folder.

    Nothing is done when ``git_dir`` already exists. The clone is made into
    ``tmp_dir`` and only renamed to ``git_dir`` after the fetch and submodule
    update finished, so ``git_dir`` is never created by a failed clone.
    Failures are logged and ``tmp_dir`` is removed; no exception is propagated.

    Args:
        git_repo: URL of the repository to clone.
        git_dir: final destination folder.
        tmp_dir: scratch folder used while cloning; removed before the clone and on failure.
    """
    if os.path.exists(git_dir):
        return
    try:
        shutil.rmtree(tmp_dir, True)
        args = {
            'url': git_repo,
            'to_path': tmp_dir,
            'allow_unsafe_protocols': True,
            'allow_unsafe_options': True,
            'filter': ['blob:none'],
        }
        ssh = os.environ.get('GIT_SSH_COMMAND', None)
        if ssh:
            args['env'] = {'GIT_SSH_COMMAND':ssh}
        # label each value with what it actually is (url/path were previously mismatched)
        log.info(f'FramePack install: url={git_repo} path={git_dir} args={args}')
        with gitpython.Repo.clone_from(**args) as repo:
            repo.remote().fetch(verbose=True)
            for submodule in repo.submodules:
                submodule.update()
        rename(tmp_dir, git_dir)
    except Exception as e:
        # best-effort install: report the failure and clean up the partial clone
        log.error(f'FramePack install: {e}')
        shutil.rmtree(tmp_dir, True)
def set_model(receipe: str=None):
    """Apply a text recipe of ``key: repo/subfolder`` lines to the active model map.

    Blank lines and lines without a ':' are ignored. Keys that are not present
    in ``default_model`` are reported with a warning and skipped.

    Args:
        receipe: multi-line recipe text; ``None`` or an empty string is a no-op.

    Raises:
        ValueError: propagated from ``split_url`` when a value is not a valid
            ``owner/repo`` or ``owner/repo/subfolder`` string.
    """
    if receipe is None or receipe == '':
        return
    lines = [line.strip() for line in receipe.split('\n') if line.strip() != '' and ':' in line]
    for line in lines:
        k, v = line.split(':', 1)
        k = k.strip()
        if k not in default_model:
            shared.log.warning(f'FramePack receipe: key={k} invalid')
            continue  # do not store entries under keys that were just reported as invalid
        model[k] = split_url(v)
        shared.log.debug(f'FramePack receipe: set {k}={model[k]}')
models.keys(): raise ValueError(f'FramePack: variant="{variant}" invalid') model['transformer']['repo'] = models[variant] if pipeline is not None: model['pipeline'] = split_url(pipeline) if text_encoder is not None: model['text_encoder'] = split_url(text_encoder) if text_encoder_2 is not None: model['text_encoder_2'] = split_url(text_encoder_2) if feature_extractor is not None: model['feature_extractor'] = split_url(feature_extractor) if image_encoder is not None: model['image_encoder'] = split_url(image_encoder) if transformer is not None: model['transformer'] = split_url(transformer) # shared.log.trace(f'FramePack load: {model}') try: import diffusers from diffusers import HunyuanVideoImageToVideoPipeline, AutoencoderKLHunyuanVideo from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel from modules.framepack.pipeline.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked class FramepackHunyuanVideoPipeline(HunyuanVideoImageToVideoPipeline): # inherit and override def __init__( self, text_encoder: LlamaModel, tokenizer: LlamaTokenizerFast, text_encoder_2: CLIPTextModel, tokenizer_2: CLIPTokenizer, vae: AutoencoderKLHunyuanVideo, feature_extractor: SiglipImageProcessor, image_processor: SiglipVisionModel, transformer: HunyuanVideoTransformer3DModelPacked, scheduler, ): super().__init__( text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, vae=vae, transformer=transformer, image_processor=image_processor, scheduler=scheduler, ) self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, vae=vae, feature_extractor=feature_extractor, image_processor=image_processor, transformer=transformer, scheduler=scheduler, ) sd_models.unload_model_weights() t0 = time.time() sd_models.hf_auth_check(model["transformer"]["repo"]) sd_models.hf_auth_check(model["text_encoder"]["repo"]) 
sd_models.hf_auth_check(model["text_encoder_2"]["repo"]) offline_config = {} if shared.opts.offline_mode: offline_config["local_files_only"] = True os.environ['HF_HUB_OFFLINE'] = '1' else: os.environ.pop('HF_HUB_OFFLINE', None) os.unsetenv('HF_HUB_OFFLINE') shared.log.debug(f'FramePack load: module=llm {model["text_encoder"]}') load_args, quant_args = model_quant.get_dit_args({}, module='TE', device_map=True) text_encoder = LlamaModel.from_pretrained(model["text_encoder"]["repo"], subfolder=model["text_encoder"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args, **offline_config) tokenizer = LlamaTokenizerFast.from_pretrained(model["tokenizer"]["repo"], subfolder=model["tokenizer"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **offline_config) text_encoder.requires_grad_(False) text_encoder.eval() sd_models.move_model(text_encoder, devices.cpu) shared.log.debug(f'FramePack load: module=te {model["text_encoder_2"]}') text_encoder_2 = CLIPTextModel.from_pretrained(model["text_encoder_2"]["repo"], subfolder=model["text_encoder_2"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, **offline_config) tokenizer_2 = CLIPTokenizer.from_pretrained(model["pipeline"]["repo"], subfolder='tokenizer_2', cache_dir=shared.opts.hfcache_dir, **offline_config) text_encoder_2.requires_grad_(False) text_encoder_2.eval() sd_models.move_model(text_encoder_2, devices.cpu) shared.log.debug(f'FramePack load: module=vae {model["vae"]}') vae = AutoencoderKLHunyuanVideo.from_pretrained(model["vae"]["repo"], subfolder=model["vae"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, **offline_config) vae.requires_grad_(False) vae.eval() vae.enable_slicing() vae.enable_tiling() sd_models.move_model(vae, devices.cpu) shared.log.debug(f'FramePack load: module=encoder {model["feature_extractor"]} model={model["image_encoder"]}') feature_extractor = SiglipImageProcessor.from_pretrained(model["feature_extractor"]["repo"], 
subfolder=model["feature_extractor"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **offline_config) image_encoder = SiglipVisionModel.from_pretrained(model["image_encoder"]["repo"], subfolder=model["image_encoder"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir, **offline_config) image_encoder.requires_grad_(False) image_encoder.eval() sd_models.move_model(image_encoder, devices.cpu) shared.log.debug(f'FramePack load: module=transformer {model["transformer"]}') dit_repo = model["transformer"]["repo"] load_args, quant_args = model_quant.get_dit_args({}, module='Model', device_map=True) transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(dit_repo, subfolder=model["transformer"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args, **offline_config) transformer.high_quality_fp32_output_for_inference = False transformer.requires_grad_(False) transformer.eval() sd_models.move_model(transformer, devices.cpu) shared.sd_model = FramepackHunyuanVideoPipeline( text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, vae=vae, feature_extractor=feature_extractor, image_processor=image_encoder, transformer=transformer, scheduler=None, ) shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(dit_repo) # pylint: disable=attribute-defined-outside-init shared.sd_model.sd_model_checkpoint = dit_repo # pylint: disable=attribute-defined-outside-init shared.sd_model = model_quant.do_post_load_quant(shared.sd_model, allow=False) t1 = time.time() diffusers.loaders.peft._SET_ADAPTER_SCALE_FN_MAPPING['HunyuanVideoTransformer3DModelPacked'] = lambda model_cls, weights: weights # pylint: disable=protected-access shared.log.info(f'FramePack load: model={shared.sd_model.__class__.__name__} variant="{variant}" type={shared.sd_model_type} time={t1-t0:.2f}') sd_models.apply_balanced_offload(shared.sd_model) devices.torch_gc(force=True, reason='load') except Exception as 
e: shared.log.error(f'FramePack load: {e}') errors.display(e, 'FramePack') shared.state.end() return None shared.state.end() return variant def unload_model(): sd_models.unload_model_weights() ================================================ FILE: modules/framepack/framepack_ui.py ================================================ import gradio as gr from modules import ui_sections, ui_video_vlm from modules.framepack import framepack_load from modules.framepack.framepack_worker import get_latent_paddings from modules.framepack.framepack_wrappers import load_model, unload_model from modules.framepack.framepack_wrappers import run_framepack # pylint: disable=wrong-import-order def change_sections(duration, mp4_fps, mp4_interpolate, latent_ws, variant): num_sections = len(get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant)) num_frames = (latent_ws * 4 - 3) * num_sections + 1 return gr.update(value=f'Target video: {num_frames} frames in {num_sections} sections'), gr.update(lines=max(2, 2*num_sections//3)) def create_ui(prompt, negative, styles, _overrides, init_image, last_image, mp4_fps, mp4_interpolate, mp4_codec, mp4_ext, mp4_opt, mp4_video, mp4_frames, mp4_sf): with gr.Row(): with gr.Column(variant='compact', elem_id="framepack_settings", elem_classes=['settings-column'], scale=1): with gr.Row(): generate = gr.Button('Generate', elem_id="framepack_generate_btn", variant='primary', visible=False) with gr.Row(): variant = gr.Dropdown(label="FP model variant", choices=list(framepack_load.models), value='bi-directional', type='value') with gr.Row(): resolution = gr.Slider(label="FP resolution", minimum=240, maximum=1088, value=640, step=16) duration = gr.Slider(label="FP duration", minimum=1, maximum=120, value=4, step=0.1) mp4_fps = gr.Slider(label="FP target FPS", minimum=1, maximum=60, value=24, step=1) mp4_interpolate = gr.Slider(label="FP interpolation", minimum=0, maximum=10, value=0, step=1) with gr.Row(): section_html = 
gr.HTML(show_label=False, elem_id="framepack_section_html") with gr.Accordion(label="Inputs", open=False): with gr.Row(): start_weight = gr.Slider(label="FP init strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_start_weight") end_weight = gr.Slider(label="FP end strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_end_weight") vision_weight = gr.Slider(label="FP vision strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_vision_weight") with gr.Accordion(label="Sections", open=False): section_prompt = gr.Textbox(label="FP section prompts", elem_id="framepack_section_prompt", lines=2, placeholder="Optional one-line prompt suffix per each video section", interactive=True) with gr.Accordion(label="Advanced", open=False): seed = ui_sections.create_seed_inputs('control', reuse_visible=False, subseed_visible=False, accordion=False)[0] latent_ws = gr.Slider(label="FP latent window size", minimum=1, maximum=33, value=9, step=1) with gr.Row(): steps = gr.Slider(label="FP steps", minimum=1, maximum=100, value=25, step=1) shift = gr.Slider(label="FP sampler shift", minimum=0.0, maximum=10.0, value=3.0, step=0.01) with gr.Row(): cfg_scale = gr.Slider(label="FP CFG scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01) cfg_distilled = gr.Slider(label="FP distilled CFG scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01) cfg_rescale = gr.Slider(label="FP CFG re-scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01) vlm_enhance, vlm_model, vlm_system_prompt = ui_video_vlm.create_ui(prompt_element=prompt, image_element=init_image) with gr.Accordion(label="Model", open=False): with gr.Row(): btn_load = gr.Button(value="Load model", elem_id="framepack_btn_load", interactive=True) btn_unload = gr.Button(value="Unload model", elem_id="framepack_btn_unload", interactive=True) with gr.Row(): system_prompt = gr.Textbox(label="FP system prompt", elem_id="framepack_system_prompt", lines=6, 
placeholder="Optional system prompt for the model", interactive=True) with gr.Row(): receipe = gr.Textbox(label="FP model receipe", elem_id="framepack_model_receipe", lines=6, placeholder="Model receipe", interactive=True) with gr.Row(): receipe_get = gr.Button(value="Get receipe", elem_id="framepack_btn_get_model", interactive=True) receipe_set = gr.Button(value="Set receipe", elem_id="framepack_btn_set_model", interactive=True) receipe_reset = gr.Button(value="Reset receipe", elem_id="framepack_btn_reset_model", interactive=True) use_teacache = gr.Checkbox(label='FP enable TeaCache', value=True) optimized_prompt = gr.Checkbox(label='FP use optimized system prompt', value=True) use_cfgzero = gr.Checkbox(label='FP enable CFGZero', value=False) use_preview = gr.Checkbox(label='FP enable Preview', value=True) attention = gr.Dropdown(label="FP attention", choices=['Default', 'Xformers', 'FlashAttention', 'SageAttention'], value='Default', type='value') vae_type = gr.Dropdown(label="FP VAE", choices=['Full', 'Tiny', 'Remote'], value='Local', type='value') with gr.Column(elem_id='framepack-output-column', scale=2) as _column_output: with gr.Tabs(): with gr.TabItem("Video"): result_video = gr.Video(label="Video", autoplay=True, show_share_button=False, height=512, loop=True, show_label=False, elem_id="framepack_result_video") with gr.Tab("Preview"): preview_image = gr.Image(label="Current", height=512, show_label=False, elem_id="framepack_preview_image") progress_desc = gr.HTML('', show_label=False, elem_id="framepack_progress_desc") # hidden fields task_id = gr.Textbox(visible=False, value='') ui_state = gr.Textbox(visible=False, value='') state_inputs = [task_id, ui_state] framepack_outputs = [ result_video, preview_image, progress_desc, ] duration.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, variant], outputs=[section_html, section_prompt]) mp4_fps.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, 
variant], outputs=[section_html, section_prompt]) mp4_interpolate.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, variant], outputs=[section_html, section_prompt]) btn_load.click(fn=load_model, inputs=[variant, attention], outputs=framepack_outputs) btn_unload.click(fn=unload_model, outputs=framepack_outputs) receipe_get.click(fn=framepack_load.get_model, inputs=[], outputs=receipe) receipe_set.click(fn=framepack_load.set_model, inputs=[receipe], outputs=[]) receipe_reset.click(fn=framepack_load.reset_model, inputs=[], outputs=[receipe]) framepack_inputs=[ init_image, last_image, start_weight, end_weight, vision_weight, prompt, system_prompt, optimized_prompt, section_prompt, negative, styles, seed, resolution, duration, latent_ws, steps, cfg_scale, cfg_distilled, cfg_rescale, shift, use_teacache, use_cfgzero, use_preview, mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, attention, vae_type, variant, vlm_enhance, vlm_model, vlm_system_prompt, ] framepack_dict = dict( fn=run_framepack, _js="submit_framepack", inputs=state_inputs + framepack_inputs, outputs=framepack_outputs, show_progress='hidden', ) generate.click(**framepack_dict) ================================================ FILE: modules/framepack/framepack_vae.py ================================================ import torch import einops from modules import shared, devices latent_rgb_factors = [ # from comfyui [-0.0395, -0.0331, 0.0445], [0.0696, 0.0795, 0.0518], [0.0135, -0.0945, -0.0282], [0.0108, -0.0250, -0.0765], [-0.0209, 0.0032, 0.0224], [-0.0804, -0.0254, -0.0639], [-0.0991, 0.0271, -0.0669], [-0.0646, -0.0422, -0.0400], [-0.0696, -0.0595, -0.0894], [-0.0799, -0.0208, -0.0375], [0.1166, 0.1627, 0.0962], [0.1165, 0.0432, 0.0407], [-0.2315, -0.1920, -0.1355], [-0.0270, 0.0401, -0.0821], [-0.0616, -0.0997, -0.0727], [0.0249, -0.0469, -0.1703] ] latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] vae_weight = None vae_bias = None 
taesd = None # lazily loaded tiny video autoencoder, see vae_decode_tiny


def vae_decode_simple(latents):
    """Build a cheap RGB preview by projecting latent channels with fixed factors.

    Applies a 1x1x1 conv3d using `latent_rgb_factors` / `latent_rgb_factors_bias`,
    lays frames out side by side and returns a uint8 numpy array arranged as
    `(b h) (t w) c`.
    """
    global vae_weight, vae_bias # pylint: disable=global-statement
    with devices.inference_context():
        if vae_weight is None or vae_bias is None: # build the projection tensors once and reuse them
            vae_weight = torch.tensor(latent_rgb_factors, device=devices.device, dtype=devices.dtype).transpose(0, 1)[:, :, None, None, None]
            vae_bias = torch.tensor(latent_rgb_factors_bias, device=devices.device, dtype=devices.dtype)
        images = torch.nn.functional.conv3d(latents, weight=vae_weight, bias=vae_bias, stride=1, padding=0, dilation=1, groups=1)
        images = (images + 1.2) * 100 # sort-of normalized
        images = einops.rearrange(images, 'b c t h w -> (b h) (t w) c')
        # clamp before the uint8 cast: clipping after the cast cannot undo wrap-around of out-of-range values
        images = images.clamp(0, 255).to(torch.uint8).detach().cpu().numpy()
    return images


def vae_decode_tiny(latents):
    """Decode latents with the 'TAE HunyuanVideo' tiny autoencoder, loading it on first use.

    The autoencoder is moved to the compute device for decoding and back to
    cpu afterwards.
    """
    global taesd # pylint: disable=global-statement
    if taesd is None:
        from modules.vae import sd_vae_taesd
        taesd, _variant = sd_vae_taesd.get_model(variant='TAE HunyuanVideo')
        shared.log.debug(f'Video VAE: type=Tiny cls={taesd.__class__.__name__} latents={latents.shape}')
    with devices.inference_context():
        taesd = taesd.to(device=devices.device, dtype=devices.dtype)
        latents = latents.transpose(1, 2) # pipe produces NCTHW and tae wants NTCHW
        images = taesd.decode_video(latents, parallel=False, show_progress_bar=False)
        images = images.transpose(1, 2).mul_(2).sub_(1) # normalize
        taesd = taesd.to(device=devices.cpu, dtype=devices.dtype)
    return images


def vae_decode_remote(latents):
    """Decode latents through the hosted HunyuanVideo remote-decode endpoint and return the result as a tensor."""
    # from modules.vae.sd_vae_remote import remote_decode
    # images = remote_decode(latents, model_type='hunyuanvideo')
    from diffusers.utils.remote_utils import remote_decode
    images = remote_decode(
        tensor=latents.contiguous(),
        endpoint='https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
        output_type='pt',
        return_type='pt',
    )
    return images
stream = None # AsyncStream


def get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant):
    """Return the per-section latent padding list for the requested video length.

    The number of sections is derived from duration, effective fps (target fps
    divided by `mp4_interpolate + 1`) and `latent_window_size * 4`.
    For `variant == 'forward-only'` the count is rounded and the list is ascending;
    otherwise the count is truncated and the list is descending, with more than
    four sections collapsed to `[3, 2, ..., 2, 1, 0]`.
    Any error while computing (bad or missing values, zero window) yields `[0]`.
    """
    try:
        real_fps = mp4_fps / (mp4_interpolate + 1)
        sections = (total_second_length * real_fps) / (latent_window_size * 4)
        if variant == 'forward-only':
            total_latent_sections = int(max(round(sections), 1))
            latent_paddings = list(range(total_latent_sections))
        else:
            total_latent_sections = int(max(sections, 1))
            if total_latent_sections > 4: # extra padding for better quality
                latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
            else:
                latent_paddings = list(reversed(range(total_latent_sections)))
    except Exception: # deliberate best-effort: callers pass raw ui values
        latent_paddings = [0]
    return latent_paddings


def worker(
        input_image, end_image,
        start_weight, end_weight, vision_weight,
        prompts, n_prompt, system_prompt, optimized_prompt, unmodified_prompt,
        seed,
        total_second_length,
        latent_window_size,
        steps,
        cfg_scale, cfg_distilled, cfg_rescale,
        shift,
        use_teacache, use_cfgzero, use_preview,
        mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
        vae_type, variant,
        metadata:dict=None,
    ):
    """Generate a FramePack video section by section and report progress through `stream`.

    Encodes the init/end images and per-section prompts, samples each section
    returned by `get_latent_paddings`, decodes the accumulated latents and saves
    the video with `save_video`. Progress tuples and a final `('end', None)` are
    pushed to `stream.output_queue`. An interruption (raised as AssertionError by
    the step callback) optionally saves the incomplete video when
    `shared.opts.keep_incomplete` is set; other errors are logged.

    `prompts` is indexed per section; `metadata` (optional dict) receives
    `title`, `description` and `comment` entries and is passed to `save_video`.
    """
    timer.process.reset()
    memstats.reset_stats()
    if stream is None: # nothing to report to, and stream must not be dereferenced
        shared.log.error('FramePack: stream is None')
        return
    if shared.state.interrupted or shared.state.skipped:
        shared.log.info('FramePack: interrupted')
        stream.output_queue.push(('end', None))
        return
    if metadata is None: # avoid a shared mutable default, the dict is written to below
        metadata = {}

    from modules.framepack.pipeline import hunyuan
    from modules.framepack.pipeline import utils
    from modules.framepack.pipeline import k_diffusion_hunyuan

    is_f1 = variant == 'forward-only'
    total_generated_frames = 0
    total_generated_latent_frames = 0
    latent_paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant)
    num_frames = latent_window_size * 4 - 3 # number of frames to generate in each section
    metadata['title'] = 'sdnext framepack'
    metadata['description'] = f'variant:{variant} seed:{seed} steps:{steps} scale:{cfg_scale} distilled:{cfg_distilled} rescale:{cfg_rescale} shift:{shift} start:{start_weight} end:{end_weight} vision:{vision_weight}'

    videojob = shared.state.begin('Video')
    shared.state.job_count = 1

    text_encoder = shared.sd_model.text_encoder
    text_encoder_2 = shared.sd_model.text_encoder_2
    tokenizer = shared.sd_model.tokenizer
    tokenizer_2 = shared.sd_model.tokenizer_2
    feature_extractor = shared.sd_model.feature_extractor
    image_encoder = shared.sd_model.image_processor
    transformer = shared.sd_model.transformer
    sd_models.apply_balanced_offload(shared.sd_model)
    pbar = rp.Progress(rp.TextColumn('[cyan]Video'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
    task = pbar.add_task('starting', total=steps * len(latent_paddings))
    t_last = time.time()
    if not is_f1:
        prompts = list(reversed(prompts)) # non-f1 sections are iterated with descending paddings, so prompts are consumed in reverse

    def text_encode(prompt, i:int=None):
        """Encode positive and negative prompt; returns embeddings, masks and poolers."""
        jobid = shared.state.begin('TE Encode')
        pbar.update(task, description=f'text encode section={i}')
        t0 = time.time()
        torch.manual_seed(seed)
        # shared.log.debug(f'FramePack: section={i} prompt="{prompt}"')
        shared.state.textinfo = 'Text encode'
        stream.output_queue.push(('progress', (None, 'Text encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
        llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        metadata['comment'] = prompt
        if cfg_scale > 1 and n_prompt is not None and len(n_prompt) > 0:
            llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        else: # negative prompt is not used, substitute zeros
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('prompt', time.time()-t0)
        shared.state.end(jobid)
        return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n

    def latents_encode(input_image, end_image):
        """VAE-encode init image and optional end image; returns (start_latent, end_latent or None)."""
        jobid = shared.state.begin('VAE Encode')
        pbar.update(task, description='image encode')
        # shared.log.debug(f'FramePack: image encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
        t0 = time.time()
        torch.manual_seed(seed)
        stream.output_queue.push(('progress', (None, 'VAE encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        if input_image is not None:
            input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
            input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
            start_latent = framepack_vae.vae_encode(input_image_pt)
            if start_weight < 1: # blend init latent with noise for reduced init strength
                noise = torch.randn_like(start_latent)
                start_latent = start_latent * start_weight + noise * (1 - start_weight)
        if end_image is not None:
            end_image_pt = torch.from_numpy(end_image).float() / 127.5 - 1
            end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
            end_latent = framepack_vae.vae_encode(end_image_pt)
        else:
            end_latent = None
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('encode', time.time()-t0)
        shared.state.end(jobid)
        return start_latent, end_latent

    def vision_encode(input_image, end_image):
        """Run the image encoder on init (and optional end) image; returns the last hidden state scaled by `vision_weight`."""
        pbar.update(task, description='vision encode')
        # shared.log.debug(f'FramePack: vision encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
        t0 = time.time()
        shared.state.textinfo = 'Vision encode'
        stream.output_queue.push(('progress', (None, 'Vision encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        # siglip doesn't work with offload
        sd_models.move_model(feature_extractor, devices.device, force=True)
        sd_models.move_model(image_encoder, devices.device, force=True)
        preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
        image_encoder_output = image_encoder(**preprocessed)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
        if end_image is not None:
            preprocessed = feature_extractor.preprocess(images=end_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
            end_image_encoder_output = image_encoder(**preprocessed)
            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
            # NOTE(review): only the end term is divided by the weight sum; a true weighted average would divide the whole sum - confirm intent before changing
            image_encoder_last_hidden_state = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight) / (start_weight + end_weight) # use weighted approach
        image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
        sd_models.apply_balanced_offload(shared.sd_model)
        timer.process.add('vision', time.time()-t0)
        return image_encoder_last_hidden_state

    def step_callback(d):
        """Per-step sampler callback: handles interrupt/pause and pushes progress (with optional preview)."""
        nonlocal total_generated_frames, t_last
        if use_cfgzero and is_first_section and d['i'] == 0:
            d['denoised'] = d['denoised'] * 0
        t_current = time.time()
        if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
            stream.output_queue.push(('progress', (None, 'Interrupted...')))
            stream.output_queue.push(('end', None))
            raise AssertionError('Interrupted...') # handled by the interruption branch of the main try block
        if shared.state.paused:
            shared.log.debug('Sampling paused')
            while shared.state.paused:
                if shared.state.interrupted or shared.state.skipped:
                    raise AssertionError('Interrupted...')
                time.sleep(0.1)
        t_preview = time.time()
        current_step = d['i'] + 1
        shared.state.textinfo = ''
        shared.state.sampling_step = ((lattent_padding_loop-1) * steps) + current_step
        shared.state.sampling_steps = steps * len(latent_paddings)
        progress = shared.state.sampling_step / shared.state.sampling_steps
        total_generated_frames = int(max(0, total_generated_latent_frames * 4 - 3))
        its = 1 / max(t_current - t_last, 1e-6) # guard against identical timestamps on coarse clocks
        pbar.update(task, advance=1, description=f'its={its:.2f} sample={d["i"]+1}/{steps} section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)}')
        desc = f'Step {shared.state.sampling_step}/{shared.state.sampling_steps} | Current {current_step}/{steps} | Section {lattent_padding_loop}/{len(latent_paddings)} | Progress {progress:.2%}'
        if use_preview:
            preview = framepack_vae.vae_decode(d['denoised'], 'Preview')
            stream.output_queue.push(('progress', (preview, desc)))
        else:
            stream.output_queue.push(('progress', (None, desc)))
        timer.process.add('preview', time.time() - t_preview)
        t_last = t_current

    # bound before the try block since both are read by the exception handlers and the final log
    t0 = time.time()
    history_pixels = None
    try:
        with devices.inference_context(), pbar:
            height, width, _C = input_image.shape
            start_latent, end_latent = latents_encode(input_image, end_image)
            image_encoder_last_hidden_state = vision_encode(input_image, end_image)

            # Sample loop
            stream.output_queue.push(('progress', (None, 'Start sampling...')))
            generator = torch.Generator("cpu").manual_seed(seed)
            if is_f1:
                history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
            else:
                history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=devices.dtype).cpu()
            lattent_padding_loop = 0
            last_prompt = None

            for latent_padding in latent_paddings:
                current_prompt = prompts[lattent_padding_loop]
                if current_prompt != last_prompt: # re-encode only when the section prompt changes
                    llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n = text_encode(current_prompt, i=lattent_padding_loop+1)
                    last_prompt = current_prompt

                sammplejob = shared.state.begin('Sample')
                lattent_padding_loop += 1
                # shared.log.trace(f'FramePack: op=sample section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)} window={latent_window_size} size={num_frames}')
                if is_f1:
                    is_first_section, is_last_section = False, False
                else:
                    is_first_section, is_last_section = latent_padding == latent_paddings[0], latent_padding == 0
                if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
                    stream.output_queue.push(('end', None))
                    # close the jobs opened above, the normal exit path at the end of the function is skipped here
                    shared.state.end(sammplejob)
                    shared.state.end(videojob)
                    return

                if is_f1:
                    indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
                    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
                    clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
                    clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
                else:
                    latent_padding_size = latent_padding * latent_window_size
                    indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                    clean_latent_indices_pre, _blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
                    clean_latents_pre = start_latent.to(history_latents)
                    clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

                if end_image is not None and is_first_section:
                    # NOTE(review): as in vision_encode, only the end term is divided by the weight sum - confirm intent before changing
                    clean_latents_post = (clean_latents_post * start_weight / len(latent_paddings)) + (end_weight * end_latent.to(history_latents)) / (start_weight/len(latent_paddings) + end_weight) # pylint: disable=possibly-used-before-assignment
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

                sd_models.apply_balanced_offload(shared.sd_model)
                transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps, rel_l1_thresh=shared.opts.teacache_thresh)

                t_sample = time.time()
                generated_latents = k_diffusion_hunyuan.sample_hunyuan(
                    transformer=transformer,
                    sampler='unipc',
                    width=width,
                    height=height,
                    frames=num_frames,
                    num_inference_steps=steps,
                    real_guidance_scale=cfg_scale,
                    distilled_guidance_scale=cfg_distilled,
                    guidance_rescale=cfg_rescale,
                    shift=shift if shift > 0 else None,
                    generator=generator,
                    prompt_embeds=llama_vec, # pylint: disable=possibly-used-before-assignment
                    prompt_embeds_mask=llama_attention_mask, # pylint: disable=possibly-used-before-assignment
                    prompt_poolers=clip_l_pooler, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds=llama_vec_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds_mask=llama_attention_mask_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_poolers=clip_l_pooler_n, # pylint: disable=possibly-used-before-assignment
                    image_embeddings=image_encoder_last_hidden_state,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    device=devices.device,
                    dtype=devices.dtype,
                    callback=step_callback,
                )

                if is_last_section: # prepend the init latent so the video starts from the init image
                    generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
                total_generated_latent_frames += int(generated_latents.shape[2])
                if is_f1: # forward-only appends new latents, otherwise they are prepended
                    history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
                    real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
                else:
                    history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
                    real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
                sd_models.apply_balanced_offload(shared.sd_model)
                timer.process.add('sample', time.time()-t_sample)
                shared.state.end(sammplejob)

                t_vae = time.time()
                if history_pixels is None:
                    history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
                else:
                    overlapped_frames = latent_window_size * 4 - 3
                    if is_f1:
                        section_latent_frames = latent_window_size * 2
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, -section_latent_frames:], vae_type=vae_type).cpu()
                        history_pixels = utils.soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
                    else:
                        section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
                        history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
                sd_models.apply_balanced_offload(shared.sd_model)
                timer.process.add('vae', time.time()-t_vae)

                if is_last_section:
                    break

            total_generated_frames, _video_filename = save_video(
                p=None,
                pixels=history_pixels,
                audio=None,
                binary=None,
                mp4_fps=mp4_fps,
                mp4_codec=mp4_codec,
                mp4_opt=mp4_opt,
                mp4_ext=mp4_ext,
                mp4_sf=mp4_sf,
                mp4_video=mp4_video,
                mp4_frames=mp4_frames,
                mp4_interpolate=mp4_interpolate,
                pbar=pbar,
                stream=stream,
                metadata=metadata,
            )

    except AssertionError: # raised by step_callback on interrupt/skip
        shared.log.info('FramePack: interrupted')
        if shared.opts.keep_incomplete:
            save_video(
                p=None,
                pixels=history_pixels,
                audio=None,
                binary=None,
                mp4_fps=mp4_fps,
                mp4_codec=mp4_codec,
                mp4_opt=mp4_opt,
                mp4_ext=mp4_ext,
                mp4_sf=mp4_sf,
                mp4_video=mp4_video,
                mp4_frames=mp4_frames,
                mp4_interpolate=0,
                pbar=pbar,
                stream=stream,
                metadata=metadata,
            )
    except Exception as e:
        shared.log.error(f'FramePack: {e}')
        errors.display(e, 'FramePack')

    sd_models.apply_balanced_offload(shared.sd_model)
    stream.output_queue.push(('end', None))
    t1 = time.time()
    shared.log.info(f'Processed: frames={total_generated_frames} fps={total_generated_frames/(t1-t0):.2f} its={(shared.state.sampling_step)/(t1-t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
    shared.state.end(videojob)
def interpolate_prompts(prompts, steps):
    """Spread a list of section prompts over `steps` sections.

    Args:
        prompts: `None`, a string (split on commas and newlines), or an
            iterable of strings. Every entry is whitespace-stripped.
        steps: number of sections to produce prompts for.

    Returns:
        A list with one prompt per section. When the number of entries equals
        `steps` the stripped entries are returned as-is; otherwise each
        section takes the entry at its proportional position. `None` or an
        empty input yields `steps` empty strings.
    """
    blank = [''] * steps
    if prompts is None:
        return blank
    entries = re.split(r'[,\n]', prompts) if isinstance(prompts, str) else prompts
    entries = [entry.strip() for entry in entries]
    count = len(entries)
    if count == 0:
        return blank
    if count == steps:
        return entries
    # number of sections covered by a single entry (may be fractional)
    stride = steps / count
    return [entries[int(section / stride)] for section in range(steps)]
def load_model(variant, attention):
    """Load the requested FramePack variant, yielding UI status updates.

    This is a generator. Every item it yields is a 3-tuple whose last element
    is a status message; the first two are no-op `gr.update()` placeholders
    (in `run_framepack` the same positions carry the output file and the
    preview image).

    Nothing is loaded, and nothing is yielded, when the current model type is
    already 'hunyuanvideo' and `variant` matches the module-level
    `loaded_variant`.

    Args:
        variant: variant identifier forwarded to `framepack_load.load_model`.
        attention: attention backend forwarded to
            `framepack_install.install_requirements`.

    Side effects:
        Rebinds the module-level `loaded_variant` to the value returned by
        `framepack_load.load_model` (`None` is reported as a failed load).
    """
    global loaded_variant  # pylint: disable=global-statement
    if (shared.sd_model_type != 'hunyuanvideo') or (loaded_variant != variant):
        yield gr.update(), gr.update(), 'Verifying FramePack'
        framepack_install.install_requirements(attention)
        # framepack_install.git_clone(git_repo=git_repo, git_dir=git_dir, tmp_dir=tmp_dir)
        # framepack_install.git_update(git_dir=git_dir, git_commit=git_commit)
        # sys.path.append(git_dir)
        framepack_hijack.set_progress_bar_config()
        # keep the same arity as every other status update (previously a stray 4th value was yielded here)
        yield gr.update(), gr.update(), 'Model loading...'
        loaded_variant = framepack_load.load_model(variant)
        if loaded_variant is not None:
            yield gr.update(), gr.update(), 'Model loaded'
        else:
            yield gr.update(), gr.update(), 'Model load failed'
use_preview, mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, attention, vae_type, variant, vlm_enhance, vlm_model, vlm_system_prompt): variant = variant or 'bi-directional' if init_image is None: init_image = np.zeros((resolution, resolution, 3), dtype=np.uint8) mode = 't2v' elif end_image is not None: mode = 'flf2v' else: mode = 'i2v' av = check_av() if av is None: yield gr.update(), gr.update(), 'AV package not installed' return progress.add_task_to_queue(task_id) with call_queue.get_lock(): progress.start_task(task_id) yield from load_model(variant, attention) if shared.sd_model_type != 'hunyuanvideo': progress.finish_task(task_id) yield gr.update(), gr.update(), 'Model load failed' return yield gr.update(), gr.update(), 'Generate starting...' from modules.framepack.pipeline.thread_utils import AsyncStream, async_run framepack_worker.stream = AsyncStream() if seed is None or seed == '' or seed == -1: random.seed() seed = random.randrange(4294967294) seed = int(seed) torch.manual_seed(seed) num_sections = len(framepack_worker.get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant)) num_frames = (latent_ws * 4 - 3) * num_sections + 1 shared.log.info(f'FramePack start: mode={mode} variant="{variant}" frames={num_frames} sections={num_sections} resolution={resolution} seed={seed} duration={duration} teacache={use_teacache} thres={shared.opts.teacache_thresh} cfgzero={use_cfgzero}') shared.log.info(f'FramePack params: steps={steps} start={start_weight} end={end_weight} vision={vision_weight} scale={cfg_scale} distilled={cfg_distilled} rescale={cfg_rescale} shift={shift}') init_image = prepare_image(init_image, resolution) if end_image is not None: end_image = prepare_image(end_image, resolution) w, h, _c = init_image.shape p = processing.StableDiffusionProcessingVideo( sd_model=shared.sd_model, prompt=prompt, negative_prompt=negative_prompt, styles=styles, steps=steps, seed=seed, width=w, height=h, ) 
p.ops.append('video') prompts = prepare_prompts(p, init_image, prompt, section_prompt, num_sections, vlm_enhance, vlm_model, vlm_system_prompt) async_run( framepack_worker.worker, init_image, end_image, start_weight, end_weight, vision_weight, prompts, p.negative_prompt, system_prompt, optimized_prompt, vlm_enhance, seed, duration, latent_ws, p.steps, cfg_scale, cfg_distilled, cfg_rescale, shift, use_teacache, use_cfgzero, use_preview, mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, vae_type, variant, ) output_filename = None while True: flag, data = framepack_worker.stream.output_queue.next() if flag == 'file': output_filename = data yield output_filename, gr.update(), gr.update() if flag == 'progress': preview, text = data summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ') memory = shared.mem_mon.summary() stats = f"

{summary} {memory}

# Supported (height, width) buckets, keyed by base resolution.
bucket_options = {
    640: [
        (416, 960),
        (448, 864),
        (480, 832),
        (512, 768),
        (544, 704),
        (576, 672),
        (608, 640),
        (640, 608),
        (672, 576),
        (704, 544),
        (768, 512),
        (832, 480),
        (864, 448),
        (960, 416),
    ],
}


def find_nearest_bucket(h, w, resolution=640):
    """Return the (height, width) bucket whose aspect ratio is closest to h:w.

    Closeness is measured by the cross-product difference
    `abs(h * bucket_w - w * bucket_h)`, which is zero for an exact aspect
    match. When several buckets tie, the one listed last wins.

    Raises:
        KeyError: if `resolution` has no entry in `bucket_options`.
    """
    nearest = None
    smallest_gap = float('inf')
    for cand_h, cand_w in bucket_options[resolution]:
        gap = abs(h * cand_w - w * cand_h)
        if gap <= smallest_gap:
            nearest, smallest_gap = (cand_h, cand_w), gap
    return nearest
def RMSNorm_forward(self, hidden_states):
    """Root-mean-square normalisation over the last dimension.

    The mean of squares is computed in float32, the normalised result is cast
    back to the input dtype, and `self.weight` (when it is not None) is
    applied after being cast to that same dtype.

    Reads `self.eps` and `self.weight`.
    """
    source_dtype = hidden_states.dtype
    mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + self.eps)
    normalized = (hidden_states * inv_rms).to(source_dtype)
    weight = self.weight
    if weight is None:
        return normalized
    return normalized * weight.to(source_dtype)
@torch.no_grad()
def vae_decode_fake(latents):
    """Cheap RGB approximation of 16-channel latents (no VAE involved).

    Applies a fixed per-channel linear map plus bias as a 1x1x1 3D
    convolution and clamps the result to [0, 1]. The output has 3 channels
    and keeps the remaining dimensions of `latents`, as well as its device
    and dtype.
    """
    # From comfyui
    rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [0.0696, 0.0795, 0.0518],
        [0.0135, -0.0945, -0.0282],
        [0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [0.1166, 0.1627, 0.0962],
        [0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [0.0249, -0.0469, -0.1703],
    ]
    rgb_bias = [0.0259, -0.0192, -0.0761]

    like_latents = {'device': latents.device, 'dtype': latents.dtype}
    # (16, 3) table -> (3, 16, 1, 1, 1) convolution kernel
    kernel = torch.tensor(rgb_factors, **like_latents).transpose(0, 1)[:, :, None, None, None]
    offset = torch.tensor(rgb_bias, **like_latents)

    preview = torch.nn.functional.conv3d(latents, kernel, bias=offset, stride=1, padding=0, dilation=1, groups=1)
    return preview.clamp(0.0, 1.0)
@torch.no_grad()
def vae_encode(image, vae):
    """Encode `image` with `vae` and return scaled latents.

    The input is moved to `devices.device` / `devices.dtype`, a sample is
    drawn from the encoder's `latent_dist`, and the sample is multiplied by
    `vae.config.scaling_factor`.
    """
    pixels = image.to(device=devices.device, dtype=devices.dtype)
    sampled = vae.encode(pixels).latent_dist.sample()
    return sampled * vae.config.scaling_factor
def get_cu_seqlens(text_mask, img_len):
    """Build the cumulative sequence-length vector for variable-length attention.

    Each batch row occupies a slot of `max_len = text_mask.shape[1] + img_len`
    tokens. For row `i` two boundaries are emitted: the end of its valid
    tokens (`i * max_len + text_len[i] + img_len`) and the end of its slot
    (`(i + 1) * max_len`).

    Args:
        text_mask: 2-D mask of shape (batch, text_tokens); the per-row sum is
            taken as the number of valid text tokens.
        img_len: number of image tokens per row.

    Returns:
        int32 tensor of length `2 * batch + 1`, starting at 0, allocated on
        the same device as `text_mask` (previously hard-coded to "cuda",
        which failed on any other backend).
    """
    batch_size = text_mask.shape[0]
    max_len = text_mask.shape[1] + img_len
    device = text_mask.device

    valid_len = text_mask.sum(dim=1) + img_len
    slot_start = torch.arange(batch_size, device=device) * max_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=device)
    cu_seqlens[1::2] = slot_start + valid_len  # end of the valid tokens of each row
    cu_seqlens[2::2] = slot_start + max_len  # end of the padded slot of each row
    return cu_seqlens
class HunyuanAttnProcessorFlashAttnSingle:
    """Attention processor for blocks that attend over one joint image+text sequence.

    `attention_mask` is expected to unpack into
    `(cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)`, which are
    forwarded unchanged to `attn_varlen_func`.
    """

    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
        """Run joint attention and return `(image_states, text_states)`.

        Image and text tokens are concatenated along dim 1 (image first),
        projected to per-head q/k/v, q and k are normalised, rotary embedding
        is applied to the image part of q and k only, and the attention
        output is split back at the same position.
        """
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        # Split on the explicit image length: `x[:, :-n]` / `x[:, -n:]` swap
        # meaning when n == 0 (empty image part, whole tensor as "text").
        img_length = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        head_shape = (attn.heads, -1)
        query = attn.to_q(hidden_states).unflatten(2, head_shape)
        key = attn.to_k(hidden_states).unflatten(2, head_shape)
        value = attn.to_v(hidden_states).unflatten(2, head_shape)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # rotary position embedding covers image tokens only; text tokens pass through
        query = torch.cat([apply_rotary_emb_transposed(query[:, :img_length], image_rotary_emb), query[:, img_length:]], dim=1)
        key = torch.cat([apply_rotary_emb_transposed(key[:, :img_length], image_rotary_emb), key[:, img_length:]], dim=1)

        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        hidden_states = hidden_states.flatten(-2)  # merge heads back into the channel dim

        return hidden_states[:, :img_length], hidden_states[:, img_length:]
class HunyuanVideoAdaNorm(nn.Module):
    """Produce the two gates (attention, MLP) from a conditioning embedding.

    Args:
        in_features: size of the conditioning embedding.
        out_features: size of the linear output; defaults to
            `2 * in_features`. The output is split in two equal halves.
    """

    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()
        out_features = out_features or 2 * in_features
        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(gate_msa, gate_mlp)`, each with a singleton dim inserted at index 1."""
        temb = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = temb.chunk(2, dim=-1)
        return gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
class HunyuanVideoIndividualTokenRefiner(nn.Module):
    """Stack of `num_layers` `HunyuanVideoIndividualTokenRefinerBlock` modules.

    All constructor arguments other than `num_layers` are forwarded to every
    block unchanged.
    """

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()
        self.refiner_blocks = nn.ModuleList(
            [
                HunyuanVideoIndividualTokenRefinerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    attention_bias=attention_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run every refiner block in order and return the refined hidden states.

        When `attention_mask` (batch, seq_len) is given it is expanded into a
        boolean (batch, 1, seq_len, seq_len) mask that is True where both the
        query and the key token are valid; column 0 is then forced to True.
        Without a mask the blocks receive `None`.
        """
        self_attn_mask = None
        if attention_mask is not None:
            valid = attention_mask.to(hidden_states.device).bool()
            # pairwise validity via broadcasting: [b, 0, i, j] = valid[b, i] & valid[b, j]
            self_attn_mask = valid[:, None, :, None] & valid[:, None, None, :]
            self_attn_mask[:, :, :, 0] = True

        for block in self.refiner_blocks:
            hidden_states = block(hidden_states, temb, self_attn_mask)

        return hidden_states
class HunyuanVideoRotaryPosEmbed(nn.Module):
    """Rotary position table over (frame, row, column) positions.

    Args:
        rope_dim: three channel counts `(DT, DY, DX)`, one per axis.
        theta: base of the frequency progression.
    """

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

    @torch.no_grad()
    def get_frequency(self, dim, pos):
        """Return `(cos, sin)` tables of shape (dim, T, H, W) for the position grid `pos`."""
        grid_shape = tuple(pos.shape)  # (T, H, W)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim
        inv_freq = 1.0 / (self.theta ** exponents)
        angles = torch.outer(inv_freq, pos.reshape(-1))
        # each frequency is used for two consecutive channels
        angles = angles.unflatten(-1, grid_shape).repeat_interleave(2, dim=0)
        return angles.cos(), angles.sin()

    @torch.no_grad()
    def forward_inner(self, frame_indices, height, width, device):
        """Build the table for one sample: all cos channels (T, Y, X) followed by all sin channels."""
        frames = frame_indices.to(device=device, dtype=torch.float32)
        rows = torch.arange(0, height, device=device, dtype=torch.float32)
        cols = torch.arange(0, width, device=device, dtype=torch.float32)
        grids = torch.meshgrid(frames, rows, cols, indexing="ij")

        cos_parts, sin_parts = [], []
        for axis_dim, grid in zip((self.DT, self.DY, self.DX), grids):
            axis_cos, axis_sin = self.get_frequency(axis_dim, grid)
            cos_parts.append(axis_cos)
            sin_parts.append(axis_sin)

        table = torch.cat(cos_parts + sin_parts, dim=0)
        return table.to(device)

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        """Stack the per-sample tables for a batch of frame-index rows along a new dim 0."""
        per_sample = [self.forward_inner(sample, height, width, device) for sample in frame_indices.unbind(0)]
        return torch.stack(per_sample, dim=0)
class AdaLayerNormContinuous(nn.Module):
    """Layer norm whose scale and shift are predicted from a conditioning embedding.

    The conditioning embedding goes through SiLU and a linear layer that
    outputs `2 * embedding_dim` values, split into `scale` and `shift`.

    Raises:
        ValueError: if `norm_type` is anything other than "layer_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type != "layer_norm":
            raise ValueError(f"unknown norm_type {norm_type}")
        self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Return `norm(x) * (1 + scale) + shift`, with scale/shift broadcast over dim -2."""
        modulation = self.linear(self.silu(emb.unsqueeze(-2)))
        scale, shift = modulation.chunk(2, dim=-1)
        return self.norm(x) * (1 + scale) + shift
Input normalization norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) norm_hidden_states, norm_encoder_hidden_states = ( norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :], ) # 2. Attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) attn_output = torch.cat([attn_output, context_attn_output], dim=1) # 3. Modulation and residual connection hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states = gate * self.proj_out(hidden_states) hidden_states = hidden_states + residual hidden_states, encoder_hidden_states = ( hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :], ) return hidden_states, encoder_hidden_states class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float, qk_norm: str = "rms_norm", ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, added_kv_proj_dim=hidden_size, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=hidden_size, context_pre_only=False, bias=True, processor=HunyuanAttnProcessorFlashAttnDouble(), qk_norm=qk_norm, eps=1e-6, ) self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") def forward( self, hidden_states: 
torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) # 2. Joint attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=freqs_cis, ) # 3. Modulation and residual connection hidden_states = hidden_states + attn_output * gate_msa encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa norm_hidden_states = self.norm2(hidden_states) norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp # 4. 
class ClipVisionProjection(nn.Module):
    """Two-layer MLP: ``Linear(in, 3 * out) -> SiLU -> Linear(3 * out, out)``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The hidden layer is three times as wide as the output.
        hidden_channels = out_channels * 3
        self.up = nn.Linear(in_channels, hidden_channels)
        self.down = nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Project *x* (last axis ``in_channels``) to ``out_channels`` features."""
        widened = self.up(x)
        activated = nn.functional.silu(widened)
        return self.down(activated)
num_layers: int = 20, num_single_layers: int = 40, num_refiner_layers: int = 2, mlp_ratio: float = 4.0, patch_size: int = 2, patch_size_t: int = 1, qk_norm: str = "rms_norm", guidance_embeds: bool = True, # pylint: disable=unused-argument text_embed_dim: int = 4096, pooled_projection_dim: int = 768, rope_theta: float = 256.0, rope_axes_dim: Tuple[int] = (16, 56, 56), has_image_proj=False, image_proj_dim=1152, has_clean_x_embedder=False, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels # 1. Latent and condition embedders self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) self.clean_x_embedder = None self.image_projection = None # 2. RoPE self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) # 3. Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( [ HunyuanVideoTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm ) for _ in range(num_layers) ] ) # 4. Single stream transformer blocks self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm ) for _ in range(num_single_layers) ] ) # 5. 
def gradient_checkpointing_method(self, block, *args):
    """Run ``block(*args)``, through activation checkpointing when it is enabled.

    Args:
        block: Any callable (typically a sub-module or one of its layers).
        *args: Positional arguments forwarded unchanged to *block*.

    Returns:
        Whatever *block* returns.
    """
    # Fast path: checkpointing disabled, call the block directly.
    if not self.use_gradient_checkpointing:
        return block(*args)
    # The non-reentrant variant is requested explicitly, as in the rest of the model.
    return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
self.gradient_checkpointing_method(self.x_embedder.proj, latents) B, C, T, H, W = hidden_states.shape if latent_indices is None: latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) hidden_states = hidden_states.flatten(2).transpose(1, 2) rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) rope_freqs = rope_freqs.flatten(2).transpose(1, 2) if clean_latents is not None and clean_latent_indices is not None: clean_latents = clean_latents.to(hidden_states) clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) clean_latents = clean_latents.flatten(2).transpose(1, 2) clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) if clean_latents_2x is not None and clean_latent_2x_indices is not None: clean_latents_2x = clean_latents_2x.to(hidden_states) clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device) clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) if clean_latents_4x is not None and clean_latent_4x_indices is not None: clean_latents_4x = 
def forward(
    self,
    hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
    latent_indices=None,
    clean_latents=None, clean_latent_indices=None,
    clean_latents_2x=None, clean_latent_2x_indices=None,
    clean_latents_4x=None, clean_latent_4x_indices=None,
    image_embeddings=None,
    attention_kwargs=None, return_dict=True
):
    """Run the packed transformer on a 5-D latent and return the prediction.

    Args:
        hidden_states: 5-D latent, unpacked below as
            (batch, channels, frames, height, width).
        timestep, guidance, pooled_projections: Inputs to ``time_text_embed``,
            which produces the modulation embedding ``temb``.
        encoder_hidden_states, encoder_attention_mask: Text tokens and their
            mask; refined by ``context_embedder``.
        latent_indices, clean_latents*, clean_latent*_indices: Forwarded as-is
            to ``process_input_hidden_states``.
        image_embeddings: Required when ``self.image_projection`` is installed.
        attention_kwargs: Only defaulted to ``{}``; not otherwise read here.
        return_dict: Selects the return type, see below.

    Returns:
        ``Transformer2DModelOutput(sample=...)`` when *return_dict* is true,
        otherwise a 1-tuple ``(hidden_states,)``.
    """

    if attention_kwargs is None:
        attention_kwargs = {}

    batch_size, num_channels, num_frames, height, width = hidden_states.shape
    p, p_t = self.config['patch_size'], self.config['patch_size_t']
    # Token-grid size after patchification; used at the end to un-patchify.
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p
    post_patch_width = width // p
    # Number of tokens that belong to the latent being predicted.
    original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

    # Patch-embeds the latent and prepends any clean-latent tokens (with matching RoPE).
    hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)

    temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
    encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)

    if self.image_projection is not None:
        assert image_embeddings is not None, 'You must use image embeddings!'
        extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
        # Projected image tokens are always valid, so their mask is all ones.
        extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)

        # must cat before (not after) encoder_hidden_states, due to attn masking
        encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
        encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)

    if batch_size == 1:
        # With a single sample no mask or variable-length attention is needed:
        # cropping the text tokens to the number of valid ones is mathematically
        # the same thing. If another implementation disagrees, that implementation is wrong.
        text_len = encoder_attention_mask.sum().item()
        encoder_hidden_states = encoder_hidden_states[:, :text_len]
        attention_mask = None, None, None, None
    else:
        img_seq_len = hidden_states.shape[1]
        txt_seq_len = encoder_hidden_states.shape[1]

        # NOTE(review): presumably cumulative sequence lengths for a var-len
        # attention kernel - confirm against get_cu_seqlens and the processors.
        cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
        cu_seqlens_kv = cu_seqlens_q
        max_seqlen_q = img_seq_len + txt_seq_len
        max_seqlen_kv = max_seqlen_q

        attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv

    if self.enable_teacache:
        # TeaCache: compare the first block's modulated input with the previous
        # call's; when the accumulated change stays under the threshold, reuse the
        # cached residual instead of running the blocks.
        modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

        # The first and the last step of a run are always computed in full.
        if self.cnt == 0 or self.cnt == self.num_steps-1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # Relative L1 change, rescaled by the fitted polynomial and accumulated.
            curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
            should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

            if should_calc:
                self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = modulated_inp
        self.cnt += 1

        # Wrap the step counter so the next sampling run starts from step 0.
        if self.cnt == self.num_steps:
            self.cnt = 0

        if not should_calc:
            # Skip all blocks: apply the residual recorded on the last full pass.
            hidden_states = hidden_states + self.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()

            for _block_id, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    rope_freqs
                )

            for _block_id, block in enumerate(self.single_transformer_blocks):
                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    rope_freqs
                )

            # Remember what the blocks added so later steps can reuse it.
            self.previous_residual = hidden_states - ori_hidden_states
    else:
        # Regular path: dual-stream blocks first, then single-stream blocks.
        for _block_id, block in enumerate(self.transformer_blocks):
            hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                block,
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask,
                rope_freqs
            )

        for _block_id, block in enumerate(self.single_transformer_blocks):
            hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                block,
                hidden_states,
                encoder_hidden_states,
                temb,
                attention_mask,
                rope_freqs
            )

    hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

    # Clean-latent tokens were prepended, so the tokens to predict are the trailing ones.
    hidden_states = hidden_states[:, -original_context_length:, :]

    if self.high_quality_fp32_output_for_inference:
        hidden_states = hidden_states.to(dtype=torch.float32)
        # Casts the output projection in place (once) so it matches the fp32 input.
        if self.proj_out.weight.dtype != torch.float32:
            self.proj_out.to(dtype=torch.float32)

    hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

    # Un-patchify: fold the token axis and the per-patch features back into a 5-D latent.
    hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
                                     t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
                                     pt=p_t, ph=p, pw=p)

    if return_dict:
        return Transformer2DModelOutput(sample=hidden_states)

    # Trailing comma is intentional: the non-dict return is a 1-tuple.
    return hidden_states,
@torch.inference_mode()
def sample_hunyuan(
        transformer,
        sampler='unipc',
        initial_latent=None,
        concat_latent=None,
        strength=1.0,
        width=512,
        height=512,
        frames=16,
        real_guidance_scale=1.0,
        distilled_guidance_scale=6.0,
        guidance_rescale=0.0,
        shift=None,
        num_inference_steps=25,
        batch_size=None,
        generator=None,
        prompt_embeds=None,
        prompt_embeds_mask=None,
        prompt_poolers=None,
        negative_prompt_embeds=None,
        negative_prompt_embeds_mask=None,
        negative_prompt_poolers=None,
        dtype=torch.bfloat16,
        device=None,
        negative_kwargs=None,
        callback=None,
        **kwargs,
):
    """Sample video latents from *transformer* with the flow-matching UniPC sampler.

    Args:
        transformer: Model to sample from; wrapped with ``fm_wrapper``.
        sampler: Sampler name; only ``'unipc'`` is implemented.
        initial_latent: Optional start latent, blended with the noise using the
            first sigma (after scaling all sigmas by *strength*).
        concat_latent: Optional latent forwarded to the sampler, cast to match
            the noise latents.
        strength: Sigma multiplier, applied only when *initial_latent* is given.
        width, height, frames: Pixel-space size; the latent is
            ``(batch, 16, (frames + 3) // 4, height // 8, width // 8)``.
        real_guidance_scale, guidance_rescale: Passed to the sampler as
            ``cfg_scale`` / ``cfg_rescale``.
        distilled_guidance_scale: Multiplied by 1000 and passed as ``guidance``.
        shift: Optional time shift; when ``None`` it is derived from the
            sequence length via ``calculate_flux_mu``.
        num_inference_steps: Number of sigma intervals.
        batch_size: Defaults to ``prompt_embeds.shape[0]``.
        generator: Optional ``torch.Generator`` for the initial noise. When
            ``None`` the global RNG of the target device is used.
        prompt_*, negative_prompt_*: Conditioning tensors, repeated to
            *batch_size*.
        dtype: dtype of the guidance tensor and of the sampler.
        device: Target device; defaults to ``transformer.device``.
        negative_kwargs: Extra kwargs overriding ``kwargs`` for the negative branch.
        callback: Per-step callback forwarded to the sampler.
        **kwargs: Extra model kwargs for both branches.

    Returns:
        The sampler's result.

    Raises:
        NotImplementedError: If *sampler* is not ``'unipc'``.
    """
    # Reject unsupported samplers up front instead of after all the setup work.
    if sampler != 'unipc':
        raise NotImplementedError(f'Sampler {sampler} is not supported.')

    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Noise must be drawn on the generator's own device. ``generator`` defaults
    # to None, in which case we draw directly on the target device (previously
    # this dereferenced ``generator.device`` and crashed).
    noise_device = generator.device if generator is not None else device
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8),
        generator=generator,
        device=noise_device,
    ).to(device=device, dtype=torch.float32)

    _B, _C, T, H, W = latents.shape
    seq_length = T * H * W // 4

    mu = calculate_flux_mu(seq_length, exp_max=7.0) if shift is None else math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # Start from a partially-noised version of the provided latent.
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        )
    )

    return sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
class FIFOQueue:
    """Minimal lock-protected first-in/first-out queue.

    Items are stored in a ``collections.deque`` so removing from the front is
    O(1); the previous ``list.pop(0)`` shifted every remaining element.
    """

    def __init__(self):
        # Imported locally so this class does not rely on a module-level import.
        from collections import deque

        self.queue = deque()
        self.lock = Lock()

    def push(self, item):
        """Append *item* to the back of the queue."""
        with self.lock:
            self.queue.append(item)

    def pop(self):
        """Remove and return the front item, or ``None`` when the queue is empty."""
        with self.lock:
            if self.queue:
                return self.queue.popleft()
            return None

    def top(self):
        """Return the front item without removing it, or ``None`` when empty."""
        with self.lock:
            if self.queue:
                return self.queue[0]
            return None

    def next(self):
        """Block until an item is available, then remove and return it.

        Polls once per millisecond; the lock is released while sleeping so
        producers can push.
        """
        while True:
            with self.lock:
                if self.queue:
                    return self.queue.popleft()

            time.sleep(0.001)
variant='bh1'): self.model = model self.variant = variant self.extra_args = extra_args def model_fn(self, x, t): return self.model(x, t, **self.extra_args) def update_fn(self, x, model_prev_list, t_prev_list, t, order): assert order <= len(model_prev_list) dims = x.dim() t_prev_0 = t_prev_list[-1] lambda_prev_0 = - torch.log(t_prev_0) lambda_t = - torch.log(t) model_prev_0 = model_prev_list[-1] h = lambda_t - lambda_prev_0 rks = [] D1s = [] for i in range(1, order): t_prev_i = t_prev_list[-(i + 1)] model_prev_i = model_prev_list[-(i + 1)] lambda_prev_i = - torch.log(t_prev_i) rk = ((lambda_prev_i - lambda_prev_0) / h)[0] rks.append(rk) D1s.append((model_prev_i - model_prev_0) / rk) rks.append(1.) rks = torch.tensor(rks, device=x.device) R = [] b = [] hh = -h[0] h_phi_1 = torch.expm1(hh) h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.variant == 'bh1': B_h = hh elif self.variant == 'bh2': B_h = torch.expm1(hh) else: raise NotImplementedError('Bad variant!') for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= (i + 1) h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) b = torch.tensor(b, device=x.device) use_predictor = len(D1s) > 0 if use_predictor: D1s = torch.stack(D1s, dim=1) if order == 2: rhos_p = torch.tensor([0.5], device=b.device) else: rhos_p = linalg_solve(R[:-1, :-1], b[:-1], x.device) else: D1s = None rhos_p = None if order == 1: rhos_c = torch.tensor([0.5], device=b.device) else: rhos_c = linalg_solve(R, b, x.device) x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0 if use_predictor: pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) else: pred_res = 0 x_t = x_t_ - expand_dims(B_h, dims) * pred_res model_t = self.model_fn(x_t, t) if D1s is not None: corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) else: corr_res = 0 D1_t = model_t - model_prev_0 x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t) return x_t, 
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Run the flow-matching UniPC sampler and return its result.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``.
        noise: Starting latent.
        sigmas: Sigma schedule to step through.
        extra_args: Extra keyword arguments for every model call; ``None``
            means no extra arguments.
        callback: Optional per-step callback.
        disable: Disable the progress bar.
        variant: ``'bh1'`` or ``'bh2'``.
    """
    assert variant in ['bh1', 'bh2']
    # The sampler unpacks ``extra_args`` with ``**`` on each model call, which
    # fails on None; normalise the default to an empty dict.
    if extra_args is None:
        extra_args = {}
    return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Scale a 4-D image batch to cover the target size, then center-crop it.

    Args:
        image: Tensor unpacked as (batch, channels, height, width).
        target_width: Output width.
        target_height: Output height.

    Returns:
        *image* itself when it already has the target size, otherwise the
        bilinearly resized and center-cropped tensor.
    """
    _, _, src_h, src_w = image.shape

    if (src_h, src_w) == (target_height, target_width):
        return image

    # Use the larger ratio so the resized image covers the target on both axes.
    scale = max(target_width / src_w, target_height / src_h)
    new_h = int(round(src_h * scale))
    new_w = int(round(src_w * scale))

    scaled = torch.nn.functional.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)

    y0 = (new_h - target_height) // 2
    x0 = (new_w - target_width) // 2
    return scaled[:, :, y0:y0 + target_height, x0:x0 + target_width]
def write_to_json(data, file_path):
    """Write *data* as indented JSON to *file_path* via a temporary file.

    The JSON is first written to ``file_path + ".tmp"`` and then moved over
    the destination with ``os.replace``, so an existing file is only replaced
    by a completely written one. If serialization or writing fails, the
    temporary file is removed and the original exception is re-raised.
    """
    temp_file_path = file_path + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
    except BaseException:
        # json.dump may fail midway (e.g. non-serializable value); do not leave
        # the partial temp file behind.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    os.replace(temp_file_path, file_path)
    return
def soft_append_bcthw(history, current, overlap=0):
    """Append *current* to *history* along axis 2, cross-fading *overlap* frames.

    With ``overlap <= 0`` the tensors are simply concatenated. Otherwise the
    last *overlap* frames of *history* and the first *overlap* frames of
    *current* are mixed with weights running linearly from 1 to 0 (history)
    and 0 to 1 (current), and the result is cast back to match *history*.
    """
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    # Weight of the history frames across the overlap, broadcast over b, c, h, w.
    fade = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    history_tail = history[:, :, -overlap:]
    current_head = current[:, :, :overlap]
    blended = fade * history_tail + (1 - fade) * current_head

    pieces = [history[:, :, :-overlap], blended, current[:, :, overlap:]]
    return torch.cat(pieces, dim=2).to(history)
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors of the same rank, zero-padding each to the per-axis maximum.

    Equal shapes are added directly. Otherwise both tensors are copied into
    the leading corner of zero tensors of the element-wise maximum shape and
    those are summed. The padded buffers use the promoted dtype of the inputs
    and the device of *tensor1*, so the result has the same dtype as a plain
    ``tensor1 + tensor2`` would (previously it was always default-dtype on CPU).
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # Match what ordinary addition would produce instead of torch's defaults.
    dtype = torch.result_type(tensor1, tensor2)
    device = tensor1.device

    padded_tensor1 = torch.zeros(new_shape, dtype=dtype, device=device)
    padded_tensor2 = torch.zeros(new_shape, dtype=dtype, device=device)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map scalar `x` from [x_min, x_max] onto [y_min, y_max], clamping outside the range.

    The normalized position is clamped to [0, 1], raised to the power `sigma`,
    and then used to blend between `y_min` and `y_max`.

    A zero-width interval (`x_min == x_max`) is treated as a step at that point:
    `y_min` below it, `y_max` at or above it. Previously this case raised
    ZeroDivisionError.

    Returns:
        The interpolated value as a scalar.
    """
    span = x_max - x_min
    if span == 0:
        # degenerate ramp: no slope to interpolate on, so behave like a step
        t = 1.0 if x >= x_max else 0.0
    else:
        t = max(0.0, min((x - x_min) / span, 1.0))
    return y_min + (t ** sigma) * (y_max - y_min)
def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Pad `x` along `dim` until it has at least `minimal_length` entries.

    Args:
        x: tensor to extend.
        dim: axis to extend; negative values count from the last axis.
        minimal_length: required size of `dim` after padding.
        zero_pad: pad with zeros when True, otherwise repeat the last slice
            along `dim`.

    Returns:
        `x` itself when it is already long enough, otherwise a new tensor with
        the padding concatenated after the original data.
    """
    original_length = int(x.shape[dim])
    if original_length >= minimal_length:
        return x

    missing = minimal_length - original_length
    if zero_pad:
        padding_shape = list(x.shape)
        padding_shape[dim] = missing
        padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
    else:
        # The index tuple is built by position, so a negative dim must be
        # normalized first (it previously produced too many indices).
        axis = dim if dim >= 0 else dim + x.dim()
        idx = (slice(None),) * axis + (slice(-1, None),)
        last_element = x[idx]  # keeps `dim` with size 1
        padding = last_element.repeat_interleave(missing, dim=dim)
    return torch.cat([x, padding], dim=dim)
def group_files_by_folder(all_files):
    """Bucket file paths by the name of their immediate parent folder.

    Only the last component of the parent directory is used as the key, so
    files from different locations whose parent folders share a name end up in
    the same group. Groups, and files within a group, keep first-seen order.

    Returns:
        A list of lists of paths, one inner list per folder name.
    """
    buckets = {}
    for path in all_files:
        parent = os.path.basename(os.path.dirname(path))
        buckets.setdefault(parent, []).append(path)
    return list(buckets.values())
def fm_wrapper(transformer, t_scale=1000.0):
    """Wrap `transformer` as a `k_model(x, sigma, **extra_args)` callable that returns x0.

    The returned function reads `dtype`, `cfg_scale`, `cfg_rescale`,
    `concat_latent`, `positive` and (unless `cfg_scale == 1.0`) `negative` from
    `extra_args`. It runs the transformer with `timestep = sigma * t_scale`,
    combines the positive/negative predictions with classifier-free guidance,
    applies `rescale_noise_cfg`, and returns `x - pred * sigma` in the dtype
    that `x` had on entry.
    """
    def k_model(x, sigma, **extra_args):
        compute_dtype = extra_args['dtype']
        guidance = extra_args['cfg_scale']
        rescale = extra_args['cfg_rescale']
        extra_latent = extra_args['concat_latent']

        input_dtype = x.dtype
        sigma = sigma.float()
        x = x.to(compute_dtype)
        timestep = (sigma * t_scale).to(compute_dtype)
        # optional conditioning latent is stacked onto the input along dim 1
        model_input = x if extra_latent is None else torch.cat([x, extra_latent.to(x)], dim=1)

        def run(branch):
            outputs = transformer(hidden_states=model_input, timestep=timestep, return_dict=False, **extra_args[branch])
            return outputs[0].float()

        cond = run('positive')
        # with a guidance scale of exactly 1 the negative pass is skipped
        uncond = torch.zeros_like(cond) if guidance == 1.0 else run('negative')
        guided = uncond + guidance * (cond - uncond)
        noise = rescale_noise_cfg(guided, cond, guidance_rescale=rescale)
        denoised = x.float() - noise.float() * append_dims(sigma, x.ndim)
        return denoised.to(dtype=input_dtype)

    return k_model
class ParamBinding:
    """Plain record describing one "send to tab" paste button and its sources.

    Stores the button, the destination `tabname`, and the optional source
    text/image components, source tab name, override-settings component and
    extra field names to transfer. Nothing is wired up here; the instance only
    holds the values.
    """

    def __init__(
        self,
        paste_button,
        tabname: str,
        source_text_component=None,
        source_image_component=None,
        source_tabname=None,
        override_settings_component=None,
        paste_field_names=None,
    ):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        # any falsy value (None, empty list) becomes a fresh list per instance
        self.paste_field_names = paste_field_names if paste_field_names else []
def add_paste_fields(tabname: str, init_img: gr.Image | gr.HTML | None, fields: list[tuple[gr.components.Component, str]] | None, override_settings_component=None):
    """Register the paste targets of a tab and derive skip-param aliases from them.

    Stores `init_img`, `fields` and `override_settings_component` in the
    module-level `paste_fields[tabname]`, records the plain-string metadata
    names in `field_names[tabname]`, extends `param_aliases` with lower-cased
    component labels and prefix-stripped elem_ids, and mirrors `fields` onto
    `modules.ui` for the txt2img/img2img/control/video tabs.

    Args:
        tabname: key under which the tab is registered.
        init_img: component that receives an image sent to this tab, if any.
        fields: `(component, name)` pairs; `name` may also be None or a callable,
            in which case the pair is ignored for names and aliases.
        override_settings_component: stored as-is alongside the fields.
    """
    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
    try:
        field_names[tabname] = [f[1] for f in fields if f[1] is not None and not callable(f[1])] if fields is not None else [] # tuple (component, label)
    except Exception as e:
        # malformed `fields` entries are logged and the tab gets no field names
        shared.log.error(f"Paste fields: tab={tabname} fields={fields} {e}")
        field_names[tabname] = []
    # Build param_aliases automatically from component labels and elem_ids
    if fields is not None:
        for component, metadata_name in fields:
            if metadata_name is None or callable(metadata_name):
                continue
            metadata_lower = metadata_name.lower()
            # Extract label from component (e.g., "Batch size" -> maps to "Batch-2")
            label = getattr(component, 'label', None)
            if label and isinstance(label, str):
                label_lower = label.lower()
                # first registration wins: an existing alias is never overwritten
                if label_lower != metadata_lower and label_lower not in param_aliases:
                    param_aliases[label_lower] = metadata_lower
            # Extract elem_id and derive variable name (e.g., "txt2img_batch_size" -> "batch_size")
            elem_id = getattr(component, 'elem_id', None)
            if elem_id and isinstance(elem_id, str):
                # Strip common prefixes like "txt2img_", "img2img_", "control_"
                var_name = elem_id
                for prefix in ['txt2img_', 'img2img_', 'control_', 'video_', 'extras_']:
                    if var_name.startswith(prefix):
                        var_name = var_name[len(prefix):]
                        break  # only the first matching prefix is removed
                var_name_lower = var_name.lower()
                if var_name_lower != metadata_lower and var_name_lower not in param_aliases:
                    param_aliases[var_name_lower] = metadata_lower
    # backwards compatibility for existing extensions
    debug(f'Paste fields: tab={tabname} fields={field_names[tabname]}')
    debug(f'All fields: {get_all_fields()}')
    debug(f'Param aliases: {param_aliases}')
    # imported inside the function rather than at module top; presumably to avoid an import cycle -- confirm
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields # compatibility
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields # compatibility
    elif tabname == 'control':
        modules.ui.control_paste_fields = fields
    elif tabname == 'video':
        modules.ui.video_paste_fields = fields
def register_paste_params_button(binding: ParamBinding):
    """Queue `binding` in the module-level `registered_param_bindings` list.

    Nothing is connected at this point; `connect_paste_params_buttons()`
    iterates the list later and attaches the click handlers.
    """
    registered_param_bindings.append(binding)
def connect_paste(button, local_paste_fields, input_comp, override_settings_component, tabname):
    """Wire `button` so that clicking it parses infotext and applies it to the UI.

    Args:
        button: gradio button receiving the click handlers.
        local_paste_fields: `(component, key)` pairs; `key` is either a parsed
            parameter name or a callable taking the parsed params dict.
        input_comp: component holding the infotext; when it is empty the text is
            read from `params_path` if that file exists.
        override_settings_component: optional dropdown that receives
            "name: value" pairs for settings that differ from current options.
        tabname: used to build the `recalculate_prompts_{tabname}` JS hook.
    """
    def paste_func(prompt):
        """Parse `prompt` and return one gradio update per registered field."""
        from modules.paths import params_path
        if prompt is None or len(prompt.strip()) == 0:
            if os.path.exists(params_path):
                with open(params_path, "r", encoding="utf8") as file:
                    prompt = file.read()
                shared.log.debug(f'Prompt parse: type="params" prompt="{prompt}"')
            else:
                prompt = ''
        else:
            shared.log.debug(f'Prompt parse: type="current" prompt="{prompt}"')
        params = parse(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []
        applied = {}
        skipped = {}
        for output, key in local_paste_fields:
            v = key(params) if callable(key) else params.get(key, None)
            if v is None:
                res.append(gr.update()) # triggers update for each gradio component even if there are no updates
                continue
            if isinstance(v, type_of_gr_update):
                res.append(v)
                applied[key] = v
                continue
            if isinstance(v, str) and v.strip() == '' and key in {'Prompt', 'Negative prompt'}:
                debug(f'Paste skip empty: "{key}"')
                res.append(gr.update())
                skipped[key] = v
                continue
            # should_skip() lower-cases its argument, so it can only be asked about
            # string keys; callable keys that yield a plain value are always applied
            if not callable(key) and should_skip(key):
                debug(f'Paste skip: "{key}"="{v}"')
                res.append(gr.update())
                skipped[key] = v
                continue
            try:
                valtype = type(output.value)
                if hasattr(output, "step") and type(output.step) == float:
                    valtype = float # component with a float step takes floats even if its default is an int
                debug(f'Paste: "{key}"="{v}" type={valtype} var={vars(output)}')
                if valtype == bool:
                    val = str(v).lower() != "false"
                elif valtype == list:
                    val = v if isinstance(v, list) else [item.strip() for item in v.split(',')]
                else:
                    val = valtype(v)
                res.append(gr.update(value=val))
                applied[key] = val
            except Exception as e:
                # best effort: one bad value must not block the remaining fields
                shared.log.error(f'Paste param: key="{key}" value="{v}" error="{e}"')
                res.append(gr.update())
        list_applied = [{k: v} for k, v in applied.items() if not callable(v) and not callable(k)]
        shared.log.debug(f"Prompt restore: apply={list_applied} skip={skipped}")
        return res

    if override_settings_component is not None:
        def paste_settings(params):
            """Build the override-settings dropdown update from parsed params."""
            params.pop('Prompt', None)
            params.pop('Negative prompt', None)
            if not params:
                # this update was previously built and discarded; return it explicitly
                return gr.Dropdown.update(value=[], choices=[], visible=False)
            vals = {}
            for param_name, setting_name in infotext_to_setting_name_mapping:
                v = params.get(param_name, None)
                if v is None:
                    continue
                if setting_name == 'sd_backend':
                    continue
                if setting_name in shared.opts.disable_apply_metadata:
                    continue
                if should_skip(param_name) or should_skip(setting_name):
                    continue
                v = shared.opts.cast_value(setting_name, v)
                current_value = getattr(shared.opts, setting_name, None)
                if v == current_value:
                    continue
                # also treat "name" and "name.ext" as the same setting value
                if type(current_value) == str and v == os.path.splitext(current_value)[0]:
                    continue
                vals[param_name] = v
            vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
            if len(vals_pairs) > 0:
                shared.log.debug(f'Settings overrides: {vals_pairs}')
            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)

        # new list: the caller's field list is not mutated
        local_paste_fields = local_paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[x[0] for x in local_paste_fields],
        show_progress='hidden',
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress='hidden',
    )
def load_gguf_state_dict(path: str, compute_dtype: torch.dtype) -> tuple[dict, dict]:
    """Read a GGUF file into a state dict of `GGMLTensor` wrappers.

    Args:
        path: GGUF file to read.
        compute_dtype: dtype passed to every `GGMLTensor` as its compute dtype.

    Returns:
        `(sd, stats)`: `sd` maps tensor names to `GGMLTensor` objects and
        `stats` maps each quantization type name to the number of tensors
        stored with it.
    """
    gguf = install_gguf()
    from .gguf_utils import TORCH_COMPATIBLE_QTYPES
    from .gguf_tensor import GGMLTensor
    sd: dict[str, GGMLTensor] = {}
    stats: dict[str, int] = {}
    reader = gguf.GGUFReader(path)
    for tensor in reader.tensors:
        torch_tensor = torch.from_numpy(tensor.data)
        # the reader's shape is reversed to obtain the torch shape
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
            # torch-compatible data can take its final shape right away;
            # quantized data keeps the raw layout and only records `shape`
            torch_tensor = torch_tensor.view(*shape)
        sd[tensor.name] = GGMLTensor(torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype)
        type_name = tensor.tensor_type.name
        stats[type_name] = stats.get(type_name, 0) + 1
    return sd, stats
def dequantize_and_run(func, args, kwargs):
    """Call `func` after replacing quantized inputs with their dequantized tensors.

    Any positional or keyword value that exposes `get_dequantized_tensor()` is
    swapped for the result of that call; everything else is passed through
    unchanged. Returns whatever `func` returns.
    """
    def unwrap(value):
        if hasattr(value, "get_dequantized_tensor"):
            return value.get_dequantized_tensor()
        return value

    positional = [unwrap(arg) for arg in args]
    keyword = {name: unwrap(value) for name, value in kwargs.items()}
    return func(*positional, **keyword)
class GGMLTensor(torch.Tensor):
    """A torch.Tensor sub-class holding a quantized GGML tensor.

    The underlying tensor is quantized, but the GGMLTensor class provides a dequantized view of the tensor on-the-fly
    when it is used in operations.
    """

    @staticmethod
    def __new__(
        cls,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        # The wrapper mirrors the metadata of the quantized storage tensor.
        # Type hinting is not supported for torch.Tensor._make_wrapper_subclass, so we ignore the errors.
        return torch.Tensor._make_wrapper_subclass(  # pyright: ignore
            cls,
            data.shape,
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )

    def __init__(
        self,
        data: torch.Tensor,
        ggml_quantization_type: gguf.GGMLQuantizationType,
        tensor_shape: torch.Size,
        compute_dtype: torch.dtype,
    ):
        self.quantized_data = data
        self._ggml_quantization_type = ggml_quantization_type
        # The dequantized shape of the tensor.
        self.tensor_shape = tensor_shape
        self.compute_dtype = compute_dtype

    def __repr__(self, *, tensor_contents=None):
        qtype = self._ggml_quantization_type
        # None is listed among the torch-compatible types, so tolerate a missing `.name`
        type_name = getattr(qtype, "name", qtype)
        return f"GGMLTensor(type={type_name}, dequantized_shape=({self.tensor_shape}))"

    @overload
    def size(self, dim: None = None) -> torch.Size: ...

    @overload
    def size(self, dim: int) -> int: ...

    def size(self, dim: int | None = None):
        """Return the size of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
        if dim is not None:
            return self.tensor_shape[dim]
        return self.tensor_shape

    @property
    def shape(self) -> torch.Size:  # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
        """The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
        return self.size()

    @property
    def quantized_shape(self) -> torch.Size:
        """The shape of the quantized tensor."""
        return self.quantized_data.shape

    def requires_grad_(self, mode: bool = True) -> torch.Tensor:
        """The GGMLTensor class is currently only designed for inference (not training). Setting requires_grad to True
        is not supported. This method is a no-op.
        """
        return self

    def get_dequantized_tensor(self):
        """Return the dequantized tensor, converted to `self.compute_dtype`.

        Torch-compatible types are only cast. Types with an entry in
        `DEQUANTIZE_FUNCTIONS` are dequantized with the torch implementation.
        Everything else goes through gguf's numpy implementation on the CPU and
        is moved back to the device of the quantized data.
        """
        qtype = self._ggml_quantization_type
        if qtype in TORCH_COMPATIBLE_QTYPES:
            return self.quantized_data.to(self.compute_dtype)
        if qtype in DEQUANTIZE_FUNCTIONS:
            return dequantize(
                data=self.quantized_data, qtype=qtype, oshape=self.tensor_shape, dtype=None
            ).to(self.compute_dtype)
        # There is no GPU implementation for this quantization type, so fallback to the numpy implementation.
        new = gguf.quants.dequantize(self.quantized_data.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(self.quantized_data.device, dtype=self.compute_dtype)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Route an op to its handler in GGML_TENSOR_OP_TABLE, dequantizing by default.

        Ops without a table entry run on dequantized inputs. If an op needs
        different handling, add it to GGML_TENSOR_OP_TABLE.
        """
        handler = GGML_TENSOR_OP_TABLE.get(func, dequantize_and_run)
        # the handlers iterate over kwargs, so never hand them None
        return handler(func, args, {} if kwargs is None else kwargs)
def dequantize_blocks_BF16(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """Expand raw bfloat16 data to float32.

    Each 16-bit value in `blocks` is moved into the upper half of a 32-bit word
    (lower half zero) and the word is reinterpreted as float32. `block_size`,
    `type_size` and `dtype` are accepted for signature parity with the other
    dequantizers and are not used: the result is always float32.
    """
    halves = blocks.view(torch.int16)
    widened = halves.to(torch.int32)
    bits = widened << 16
    return bits.view(torch.float32)
def dequantize_blocks_Q5_K(
    blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """Dequantize Q5_K blocks to a `(n_blocks, QK_K)` tensor.

    Each row of `blocks` is split, in order, into a 2-byte scale `d`, a 2-byte
    minimum scale `dmin`, `K_SCALE_SIZE` bytes of packed sub-block scales/mins,
    `QK_K // 8` bytes holding the high bit of every value, and the remaining
    bytes holding the low 4 bits. `block_size` and `type_size` are not used here.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
    # the two leading byte pairs are reinterpreted as float16 factors
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)
    # unpack the per-sub-block scale and minimum values
    sc, m = get_scale_min(scales)
    # combined scale and offset per sub-block, with a trailing axis for broadcasting
    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))
    # every low byte carries two 4-bit values: shift by 0 and 4, mask below
    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 2, 1)
    )
    # every high byte carries eight 1-bit values: shift by 0..7, mask below
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape(
        (1, 1, 8, 1)
    )
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    # 5-bit value = low nibble with the high bit placed at bit 4
    q = ql | (qh << 4)
    return (d * q - dm).reshape((n_blocks, QK_K))
qs - dm).reshape((n_blocks, QK_K)) def dequantize_blocks_Q3_K( blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None ) -> torch.Tensor: n_blocks = blocks.shape[0] hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) d = d.view(torch.float16).to(dtype) lscales, hscales = scales[:, :8], scales[:, 8:] lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( (1, 2, 1) ) lscales = lscales.reshape((n_blocks, 16)) hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor( [0, 2, 4, 6], device=d.device, dtype=torch.uint8 ).reshape((1, 4, 1)) hscales = hscales.reshape((n_blocks, 16)) scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) scales = scales.to(torch.int8) - 32 dl = (d * scales).reshape((n_blocks, 16, 1)) ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( (1, 1, 4, 1) ) qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape( (1, 1, 8, 1) ) ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 q = ql.to(torch.int8) - (qh << 2).to(torch.int8) return (dl * q).reshape((n_blocks, QK_K)) def dequantize_blocks_Q2_K( blocks: torch.Tensor, block_size: int, type_size: int, dtype: Optional[torch.dtype] = None ) -> torch.Tensor: n_blocks = blocks.shape[0] scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) d = d.view(torch.float16).to(dtype) dmin = dmin.view(torch.float16).to(dtype) # (n_blocks, 16, 1) dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 qs = qs.reshape((n_blocks, QK_K // 16, 16)) qs = dl * qs - ml return qs.reshape((n_blocks, -1)) 
DEQUANTIZE_FUNCTIONS: dict[ gguf.GGMLQuantizationType, Callable[[torch.Tensor, int, int, Optional[torch.dtype]], torch.Tensor] ] = { gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K, gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, } def is_torch_compatible(tensor: Optional[torch.Tensor]): return getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES def is_quantized(tensor: torch.Tensor): return not is_torch_compatible(tensor) def dequantize( data: torch.Tensor, qtype: gguf.GGMLQuantizationType, oshape: torch.Size, dtype: Optional[torch.dtype] = None ): """ Dequantize tensor back to usable shape/dtype """ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] dequantize_blocks = DEQUANTIZE_FUNCTIONS[qtype] rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) n_blocks = rows.numel() // type_size blocks = rows.reshape((n_blocks, type_size)) blocks = dequantize_blocks(blocks, block_size, type_size, dtype) return blocks.reshape(oshape) def to_uint32(x: torch.Tensor) -> torch.Tensor: x = x.view(torch.uint8).to(torch.int32) return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) def split_block_dims(blocks: torch.Tensor, *args): n_max = blocks.shape[1] dims = list(args) + [n_max - sum(args)] return torch.split(blocks, dims, dim=1) PATCH_TYPES = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] ================================================ FILE: modules/gr_hijack.py ================================================ 
import time from PIL import Image import gradio as gr import gradio.processing_utils from modules import scripts_manager, patches, gr_tempdir hijacked = False original_IOComponent_init = None original_Block_get_config = None original_BlockContext_init = None original_Blocks_get_config_file = None def process_kanvas(self, x): # only used when kanvas overrides gr.Image object import numpy as np from modules import errors t0 = time.time() image_data = list(x.get('image', {}).values()) image = None mask = None if image_data: width = x['imageWidth'] height = x['imageHeight'] array = np.array(image_data, dtype=np.uint8).reshape((height, width, 4)) image = Image.fromarray(array, 'RGBA') image = image.convert('RGB') mask_data = list(x.get('mask', {}).values()) if mask_data: width = x['maskWidth'] height = x['maskHeight'] array = np.array(mask_data, dtype=np.uint8).reshape((height, width, 4)) mask = Image.fromarray(array, 'RGBA') # alpha = mask.getchannel("A").convert("L") # mask = Image.merge("RGB", [alpha, alpha, alpha]) mask = mask.convert('L') t1 = time.time() errors.log.debug(f'Kanvas: image={image} mask={mask} time={t1-t0:.2f}') if image is None: return None if mask is None: return self._format_image(image) # pylint: disable=protected-access return { "image": self._format_image(image), "mask": self._format_image(mask) } # pylint: disable=protected-access def gr_image_preprocess(self, x): if x is None: return x mask = None if isinstance(x, dict) and "kanvas" in x: return process_kanvas(self, x) if isinstance(x, dict) and "image" in x: x, mask = x["image"], x["mask"] if isinstance(x, str): im = gradio.processing_utils.decode_base64_to_image(x) else: im = x im = im.convert(self.image_mode) if self.shape is not None: im = gradio.processing_utils.resize_and_crop(im, self.shape) if self.tool == "sketch" and self.source in ["upload"]: if mask is not None: mask_im = gradio.processing_utils.decode_base64_to_image(mask) if mask_im.mode == "RGBA": # whiten any opaque pixels in 
the mask alpha_data = mask_im.getchannel("A").convert("L") mask_im = Image.merge("RGB", [alpha_data, alpha_data, alpha_data]) else: mask_im = Image.new("L", im.size, 0) return { "image": self._format_image(im), "mask": self._format_image(mask_im) } # pylint: disable=protected-access return self._format_image(im) # pylint: disable=protected-access def add_classes_to_gradio_component(comp): """ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others """ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] if getattr(comp, 'multiselect', False): comp.elem_classes.append('multiselect') def IOComponent_init(self, *args, **kwargs): self.webui_tooltip = kwargs.pop('tooltip', None) if scripts_manager.scripts_current is not None: scripts_manager.scripts_current.before_component(self, **kwargs) scripts_manager.script_callbacks.before_component_callback(self, **kwargs) res = original_IOComponent_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return add_classes_to_gradio_component(self) scripts_manager.script_callbacks.after_component_callback(self, **kwargs) if scripts_manager.scripts_current is not None: scripts_manager.scripts_current.after_component(self, **kwargs) return res def Block_get_config(self): config = original_Block_get_config(self) webui_tooltip = getattr(self, 'webui_tooltip', None) if webui_tooltip: config["webui_tooltip"] = webui_tooltip config.pop('example_inputs', None) return config def BlockContext_init(self, *args, **kwargs): if scripts_manager.scripts_current is not None: scripts_manager.scripts_current.before_component(self, **kwargs) scripts_manager.script_callbacks.before_component_callback(self, **kwargs) res = original_BlockContext_init(self, *args, **kwargs) # pylint: disable=assignment-from-no-return add_classes_to_gradio_component(self) scripts_manager.script_callbacks.after_component_callback(self, **kwargs) if scripts_manager.scripts_current 
is not None: scripts_manager.scripts_current.after_component(self, **kwargs) return res def Blocks_get_config_file(self, *args, **kwargs): config = original_Blocks_get_config_file(self, *args, **kwargs) for comp_config in config["components"]: if "example_inputs" in comp_config: comp_config["example_inputs"] = {"serialized": []} return config def patch_gradio(): def wrap_gradio_js(fn): def wrapper(*args, js=None, _js=None, **kwargs): if _js is not None: js = _js return fn(*args, js=js, **kwargs) return wrapper gradio.components.Button.click = wrap_gradio_js(gradio.components.Button.click) gradio.components.Textbox.submit = wrap_gradio_js(gradio.components.Textbox.submit) gradio.components.Image.clear = wrap_gradio_js(gradio.components.Image.clear) gradio.components.Image.change = wrap_gradio_js(gradio.components.Image.change) gradio.components.Image.upload = wrap_gradio_js(gradio.components.Image.upload) gradio.components.Video.change = wrap_gradio_js(gradio.components.Video.change) gradio.components.Video.clear = wrap_gradio_js(gradio.components.Video.clear) gradio.components.Slider.change = wrap_gradio_js(gradio.components.Slider.change) gradio.components.Dropdown.change = wrap_gradio_js(gradio.components.Dropdown.change) gradio.components.File.change = wrap_gradio_js(gradio.components.File.change) gradio.components.File.clear = wrap_gradio_js(gradio.components.File.clear) gradio.components.Number.change = wrap_gradio_js(gradio.components.Number.change) gradio.components.Textbox.change = wrap_gradio_js(gradio.components.Textbox.change) gradio.components.Radio.change = wrap_gradio_js(gradio.components.Radio.change) gradio.components.Checkbox.change = wrap_gradio_js(gradio.components.Checkbox.change) gradio.components.CheckboxGroup.change = wrap_gradio_js(gradio.components.CheckboxGroup.change) gradio.components.ColorPicker.change = wrap_gradio_js(gradio.components.ColorPicker.change) gradio.layouts.Tab.select = wrap_gradio_js(gradio.layouts.Tab.select) 
gradio.components.Image.edit = lambda *args, **kwargs: None # gradio.components.image.Image.__init__ missing tool, brush_radius, mask_opacity, edit() def init(): global hijacked, original_IOComponent_init, original_Block_get_config, original_BlockContext_init, original_Blocks_get_config_file # pylint: disable=global-statement if hijacked: return gr.components.Image.preprocess = gr_image_preprocess if hasattr(gr.components, 'IOComponent'): gr.components.IOComponent.pil_to_temp_file = gr_tempdir.pil_to_temp_file original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) if not gr.__version__.startswith('3.43'): patch_gradio() hijacked = True ================================================ FILE: modules/gr_tempdir.py ================================================ import os import tempfile from collections import namedtuple from pathlib import Path from PIL import Image, PngImagePlugin from modules import shared, errors, paths Savedfile = namedtuple("Savedfile", ["name"]) debug = errors.log.trace if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None def register_tmp_file(gradio, filename): if hasattr(gradio, 'temp_file_sets'): gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} def check_tmp_file(gradio, filename): ok = False if hasattr(gradio, 'temp_file_sets'): ok = ok or any(filename in fileset for fileset in gradio.temp_file_sets) # Check resolved output paths (base + specific) base_samples = shared.opts.outdir_samples base_grids = 
shared.opts.outdir_grids resolved_paths = [ paths.resolve_output_path(base_samples, shared.opts.outdir_txt2img_samples), paths.resolve_output_path(base_samples, shared.opts.outdir_img2img_samples), paths.resolve_output_path(base_samples, shared.opts.outdir_extras_samples), paths.resolve_output_path(base_samples, shared.opts.outdir_control_samples), paths.resolve_output_path(base_samples, shared.opts.outdir_save), paths.resolve_output_path(base_samples, shared.opts.outdir_video), paths.resolve_output_path(base_samples, shared.opts.outdir_init_images), paths.resolve_output_path(base_grids, shared.opts.outdir_txt2img_grids), paths.resolve_output_path(base_grids, shared.opts.outdir_img2img_grids), paths.resolve_output_path(base_grids, shared.opts.outdir_control_grids), ] # Also check base folders directly if set if base_samples: resolved_paths.append(base_samples) if base_grids: resolved_paths.append(base_grids) for path in resolved_paths: if path: try: ok = ok or Path(path).resolve() in Path(filename).resolve().parents except Exception: pass return ok def pil_to_temp_file(self, img: Image, dir: str, format="png") -> str: # pylint: disable=redefined-builtin,unused-argument """ # original gradio implementation bytes_data = gr.processing_utils.encode_pil_to_bytes(img, format) temp_dir = Path(dir) / self.hash_bytes(bytes_data) temp_dir.mkdir(exist_ok=True, parents=True) filename = str(temp_dir / f"image.{format}") img.save(filename, pnginfo=gr.processing_utils.get_pil_metadata(img)) """ folder = dir already_saved_as = getattr(img, 'already_saved_as', None) exists = os.path.isfile(already_saved_as) if already_saved_as is not None else False debug(f'Image lookup: {already_saved_as} exists={exists}') if already_saved_as and exists: register_tmp_file(shared.demo, already_saved_as) file_obj = Savedfile(already_saved_as) name = file_obj.name debug(f'Image registered: {name}') return name if shared.opts.temp_dir != "": folder = shared.opts.temp_dir use_metadata = False metadata 
= PngImagePlugin.PngInfo() for key, value in img.info.items(): if isinstance(key, str) and isinstance(value, str): metadata.add_text(key, value) use_metadata = True if not os.path.exists(folder): os.makedirs(folder, exist_ok=True) shared.log.debug(f'Created temp folder: path="{folder}"') with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=folder) as tmp: name = tmp.name img.save(name, pnginfo=(metadata if use_metadata else None)) img.already_saved_as = name size = os.path.getsize(name) shared.log.debug(f'Save temp: image="{name}" width={img.width} height={img.height} size={size}') shared.state.image_history += 1 params = ', '.join([f'{k}: {v}' for k, v in img.info.items()]) params = params[12:] if params.startswith('parameters: ') else params if len(params) > 2: with open(paths.params_path, "w", encoding="utf8") as file: file.write(params) return name # override save to file function so that it also writes PNG info def on_tmpdir_changed(): if shared.opts.temp_dir == "": return register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) def cleanup_tmpdr(): temp_dir = shared.opts.temp_dir if temp_dir == "" or not os.path.isdir(temp_dir): temp_dir = os.path.join(paths.temp_dir, "gradio") shared.log.debug(f'Temp folder: path="{temp_dir}"') if not os.path.isdir(temp_dir): return for root, _dirs, files in os.walk(temp_dir, topdown=False): for name in files: _, extension = os.path.splitext(name) if extension not in {".png", ".jpg", ".webp", ".jxl"}: continue filename = os.path.join(root, name) os.remove(filename) ================================================ FILE: modules/hashes.py ================================================ import hashlib import os.path from rich import progress, errors from installer import log, console from modules.json_helpers import readfile, writefile from modules.paths import data_path cache_filename = os.path.join(data_path, 'data', 'cache.json') cache_data = None progress_ok = True def init_cache(): global 
cache_data # pylint: disable=global-statement if cache_data is None: cache_data = {} if not os.path.isfile(cache_filename) else readfile(cache_filename, lock=True, as_type="dict") def dump_cache(): writefile(cache_data, cache_filename) def cache(subsection): global cache_data # pylint: disable=global-statement if cache_data is None: cache_data = {} if not os.path.isfile(cache_filename) else readfile(cache_filename, lock=True, as_type="dict") s = cache_data.get(subsection, {}) cache_data[subsection] = s return s def calculate_sha256(filename, quiet=False): global progress_ok # pylint: disable=global-statement hash_sha256 = hashlib.sha256() blksize = 1024 * 1024 if not quiet: if progress_ok: try: with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=console) as f: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) except errors.LiveError: log.warning('Hash: attempting to use function in a thread') progress_ok = False if not progress_ok: with open(filename, 'rb') as f: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) else: with open(filename, 'rb') as f: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) return hash_sha256.hexdigest() def sha256_from_cache(filename, title, use_addnet_hash=False): hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") if title not in hashes: return None cached_sha256 = hashes[title].get("sha256", None) cached_mtime = hashes[title].get("mtime", 0) ondisk_mtime = os.path.getmtime(filename) if os.path.isfile(filename) else 0 if ondisk_mtime > cached_mtime or cached_sha256 is None: return None return cached_sha256 def sha256(filename, title, use_addnet_hash=False): from modules import shared global progress_ok # pylint: disable=global-statement hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") sha256_value = sha256_from_cache(filename, title, use_addnet_hash) if 
sha256_value is not None: return sha256_value if shared.cmd_opts.no_hashing: return None if not os.path.isfile(filename): return None jobid = shared.state.begin("Hash") if use_addnet_hash: if progress_ok: try: with progress.open(filename, 'rb', description=f'[cyan]Calculating hash: [yellow]{filename}', auto_refresh=True, console=shared.console) as f: sha256_value = addnet_hash_safetensors(f) except errors.LiveError: log.warning('Hash: attempting to use function in a thread') progress_ok = False if not progress_ok: with open(filename, 'rb') as f: sha256_value = addnet_hash_safetensors(f) else: sha256_value = calculate_sha256(filename) hashes[title] = { "mtime": os.path.getmtime(filename), "sha256": sha256_value } shared.state.end(jobid) dump_cache() return sha256_value def addnet_hash_safetensors(b): """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py""" hash_sha256 = hashlib.sha256() blksize = 1024 * 1024 b.seek(0) header = b.read(8) n = int.from_bytes(header, "little") offset = n + 8 b.seek(offset) for chunk in iter(lambda: b.read(blksize), b""): hash_sha256.update(chunk) return hash_sha256.hexdigest() ================================================ FILE: modules/hidiffusion/__init__.py ================================================ # Original: https://github.com/megvii-research/HiDiffusion import time from modules import shared from modules.hidiffusion import hidiffusion def apply(p, model_type): if model_type not in ['sd', 'sdxl'] and p.hidiffusion: shared.log.warning(f'HiDiffusion: class={shared.sd_model.__class__.__name__} not supported') return unapply() pipe = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model if getattr(p, 'hidiffusion', False) is True: t0 = time.time() hidiffusion.is_aggressive_raunet = shared.opts.hidiffusion_steps > 0 hidiffusion.aggressive_step = shared.opts.hidiffusion_steps if shared.opts.hidiffusion_t1 >= 0: t1 = shared.opts.hidiffusion_t1 
hidiffusion.switching_threshold_ratio_dict['sd15_1024']['T1_ratio'] = t1 hidiffusion.switching_threshold_ratio_dict['sd15_2048']['T1_ratio'] = t1 hidiffusion.switching_threshold_ratio_dict['sdxl_2048']['T1_ratio'] = t1 hidiffusion.switching_threshold_ratio_dict['sdxl_4096']['T1_ratio'] = t1 hidiffusion.switching_threshold_ratio_dict['sdxl_turbo_1024']['T1_ratio'] = t1 p.extra_generation_params['HiDiffusion Ratios'] = f'{shared.opts.hidiffusion_t1}/{shared.opts.hidiffusion_t2}' if shared.opts.hidiffusion_t2 >= 0: t2 =shared.opts.hidiffusion_t2 hidiffusion.switching_threshold_ratio_dict['sd15_1024']['T2_ratio'] = t2 hidiffusion.switching_threshold_ratio_dict['sd15_2048']['T2_ratio'] = t2 hidiffusion.switching_threshold_ratio_dict['sdxl_2048']['T2_ratio'] = t2 hidiffusion.switching_threshold_ratio_dict['sdxl_4096']['T2_ratio'] = t2 hidiffusion.switching_threshold_ratio_dict['sdxl_turbo_1024']['T2_ratio'] = t2 p.extra_generation_params['HiDiffusion Ratios'] = f'{shared.opts.hidiffusion_t1}/{shared.opts.hidiffusion_t2}' hidiffusion.apply_hidiffusion(pipe, apply_raunet=shared.opts.hidiffusion_raunet, apply_window_attn=shared.opts.hidiffusion_attn, model_type=model_type, steps=p.steps) p.extra_generation_params['HiDiffusion'] = f'{shared.opts.hidiffusion_raunet}/{shared.opts.hidiffusion_attn}/{shared.opts.hidiffusion_steps > 0}:{shared.opts.hidiffusion_steps}' t1 = time.time() shared.log.debug(f'Applying HiDiffusion: raunet={shared.opts.hidiffusion_raunet} attn={shared.opts.hidiffusion_attn} aggressive={shared.opts.hidiffusion_steps > 0}:{shared.opts.hidiffusion_steps} t1={shared.opts.hidiffusion_t1} t2={shared.opts.hidiffusion_t2} time={t1-t0:.2f} type={shared.sd_model_type} width={p.width} height={p.height}') elif hasattr(pipe, 'unet') and getattr(pipe.unet, 'hidiffusion', False): shared.log.warning('HiDiffusion: model reload recomended') def unapply(): pipe = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model if hasattr(pipe, 'unet') and 
pipe.unet is not None: hidiffusion.remove_hidiffusion(pipe) ================================================ FILE: modules/hidiffusion/hidiffusion.py ================================================ from typing import Type, Dict, Any, Tuple, Optional import math import torch import torch.nn.functional as F from diffusers.pipelines import auto_pipeline current_steps = 50 def sd15_hidiffusion_key(): modified_key = {} modified_key['down_module_key'] = ['down_blocks.0.downsamplers.0.conv'] modified_key['down_module_key_extra'] = ['down_blocks.1'] modified_key['up_module_key'] = ['up_blocks.2.upsamplers.0.conv'] modified_key['up_module_key_extra'] = ['up_blocks.2'] modified_key['windown_attn_module_key'] = [ 'down_blocks.0.attentions.0.transformer_blocks.0', 'down_blocks.0.attentions.1.transformer_blocks.0', 'up_blocks.3.attentions.0.transformer_blocks.0', 'up_blocks.3.attentions.1.transformer_blocks.0', 'up_blocks.3.attentions.2.transformer_blocks.0'] return modified_key def sdxl_hidiffusion_key(): modified_key = {} modified_key['down_module_key'] = ['down_blocks.1'] modified_key['down_module_key_extra'] = ['down_blocks.1.downsamplers.0.conv'] modified_key['up_module_key'] = ['up_blocks.1'] modified_key['up_module_key_extra'] = ['up_blocks.0.upsamplers.0.conv'] modified_key['windown_attn_module_key'] = [ 'down_blocks.1.attentions.0.transformer_blocks.0', 'down_blocks.1.attentions.0.transformer_blocks.1', 'down_blocks.1.attentions.1.transformer_blocks.0', 'down_blocks.1.attentions.1.transformer_blocks.1', 'up_blocks.1.attentions.0.transformer_blocks.0', 'up_blocks.1.attentions.0.transformer_blocks.1', 'up_blocks.1.attentions.1.transformer_blocks.0', 'up_blocks.1.attentions.1.transformer_blocks.1', 'up_blocks.1.attentions.2.transformer_blocks.0', 'up_blocks.1.attentions.2.transformer_blocks.1'] return modified_key def sdxl_turbo_hidiffusion_key(): modified_key = {} modified_key['down_module_key'] = ['down_blocks.1'] modified_key['up_module_key'] = ['up_blocks.1'] 
modified_key['windown_attn_module_key'] = [ 'down_blocks.1.attentions.0.transformer_blocks.0', 'down_blocks.1.attentions.0.transformer_blocks.1', 'down_blocks.1.attentions.1.transformer_blocks.0', 'down_blocks.1.attentions.1.transformer_blocks.1', 'up_blocks.1.attentions.0.transformer_blocks.0', 'up_blocks.1.attentions.0.transformer_blocks.1', 'up_blocks.1.attentions.1.transformer_blocks.0', 'up_blocks.1.attentions.1.transformer_blocks.1', 'up_blocks.1.attentions.2.transformer_blocks.0', 'up_blocks.1.attentions.2.transformer_blocks.1'] return modified_key # T1_ratio: see T1 introduced in the main paper. T1 = number_inference_step * T1_ratio. A higher T1_ratio can better mitigate object duplication. We set T1_ratio=0.4 by default. You'd better adjust it to fit your prompt. Only active when apply_raunet=True. # T2_ratio: see T2 introduced in the appendix, used in extreme resolution image generation. T2 = number_inference_step * T2_ratio. A higher T2_ratio can better mitigate object duplication. 
Only active when apply_raunet=True switching_threshold_ratio_dict = { 'sd15_1024': {'T1_ratio': 0.4, 'T2_ratio': 0.0}, 'sd15_2048': {'T1_ratio': 0.7, 'T2_ratio': 0.3}, 'sdxl_2048': {'T1_ratio': 0.4, 'T2_ratio': 0.0}, 'sdxl_4096': {'T1_ratio': 0.7, 'T2_ratio': 0.3}, 'sdxl_turbo_1024': {'T1_ratio': 0.5, 'T2_ratio': 0.0}, } text_to_img_controlnet_switching_threshold_ratio_dict = { 'sdxl_2048': {'T1_ratio': 0.5, 'T2_ratio': 0.0}, } controlnet_apply_steps_rate = 0.6 is_aggressive_raunet = True aggressive_step = 8 inpainting_is_aggressive_raunet = False playground_is_aggressive_raunet = False def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: # replace global self-attention with MSW-MSA class transformer_block(block_class): # Save for unpatching later _parent = block_class def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: # reference: https://github.com/microsoft/Swin-Transformer def window_partition(x, window_size, shift_size, H, W): B, _N, C = x.shape x = x.view(B,H,W,C) if H % 2 != 0 or W % 2 != 0: from modules.errors import log log.warning('HiDiffusion: The feature size is not divisible by 2') x = F.interpolate(x.permute(0,3,1,2).contiguous(), size=(window_size[0]*2, window_size[1]*2), mode='bicubic').permute(0,2,3,1).contiguous() if type(shift_size) == list or type(shift_size) == tuple: if shift_size[0] > 0: x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) else: if shift_size > 0: x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) x = x.view(B, 2, window_size[0], 2, window_size[1], C) 
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) windows = windows.view(-1, window_size[0] * window_size[1], C) return windows def window_reverse(windows, window_size, H, W, shift_size): B, _N, C = windows.shape windows = windows.view(-1, window_size[0], window_size[1], C) B = int(windows.shape[0] / 4) # 2x2 x = windows.view(B, 2, 2, window_size[0], window_size[1], -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, window_size[0]*2, window_size[1]*2, -1) if type(shift_size) == list or type(shift_size) == tuple: if shift_size[0] > 0: x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) else: if shift_size > 0: x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2)) if H % 2 != 0 or W % 2 != 0: x = F.interpolate(x.permute(0,3,1,2).contiguous(), size=(H, W), mode='bicubic').permute(0,2,3,1).contiguous() x = x.view(B, H*W, C) return x batch_size = hidden_states.shape[0] if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype) elif self.use_layer_norm: norm_hidden_states = self.norm1(hidden_states) elif self.use_ada_layer_norm_continuous: norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) elif self.use_ada_layer_norm_single: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa norm_hidden_states = norm_hidden_states.squeeze(1) else: raise ValueError("HiDiffusion: Incorrect norm used") if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) # MSW-MSA rand_num = torch.rand(1) _B, N, _C = hidden_states.shape try: 
ori_H, ori_W = self.info['size'] except Exception as e: raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e downsample_ratio = round(((ori_H*ori_W) / N)**0.5) H, W = (math.ceil(ori_H/downsample_ratio), math.ceil(ori_W/downsample_ratio)) widow_size = (math.ceil(H/2), math.ceil(W/2)) if rand_num <= 0.25: shift_size = (0,0) elif rand_num > 0.25 and rand_num <= 0.5: shift_size = (widow_size[0]//4, widow_size[1]//4) elif rand_num > 0.5 and rand_num <= 0.75: shift_size = (widow_size[0]//4*2, widow_size[1]//4*2) elif rand_num > 0.75 and rand_num <= 1: shift_size = (widow_size[0]//4*3, widow_size[1]//4*3) else: shift_size = (0,0) norm_hidden_states = window_partition(norm_hidden_states, widow_size, shift_size, H, W) # 1. Retrieve lora scale. # cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 2. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output elif self.use_ada_layer_norm_single: attn_output = gate_msa * attn_output attn_output = window_reverse(attn_output, widow_size, H, W, shift_size) hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) # 2.5 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) # 3. 
def make_diffusers_cross_attn_down_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Build a subclass of `block_class` whose forward average-pools the hidden
    states after the first resnet/attention pair during the early denoising steps.

    The returned class keeps `block_class` in `_parent` so the patch can be undone.
    Instances are expected to carry `info`, `model` and `switching_threshold_ratio`
    attributes, which are assigned by `apply_hidiffusion`.
    """
    class cross_attn_down_block(block_class):
        _parent = block_class # Save for unpatching later
        timestep = 0 # number of forward calls seen in the current run
        aggressive_raunet = False
        T1_ratio = 0
        T1_start = 0
        T1_end = 0
        T1 = 0 # to avoid confict with sdxl-turbo
        max_timestep = current_steps # evaluated once, when the class is created

        def forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            additional_residuals: Optional[torch.FloatTensor] = None,
        ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
            """
            Run the resnet/attention pairs and the optional downsamplers.

            Returns the final hidden states plus a tuple with the hidden states
            recorded after every pair (and after the downsamplers, when present).
            Raises RuntimeError when `info['size']` cannot be unpacked or when
            `self.model` is not one of 'sd15', 'sdxl', 'sdxl_turbo'.
            """
            pipeline = self.info['pipeline']
            # pipelines that do not track the step count get the value captured at patch time
            if not hasattr(pipeline, '_num_timesteps'):
                pipeline._num_timesteps = self.max_timestep # pylint: disable=protected-access
            self.max_timestep = pipeline._num_timesteps # pylint: disable=protected-access
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # pick the switching threshold for this model family and input size
            if self.model == 'sd15':
                if ori_H < 256 or ori_W < 256:
                    self.T1_ratio = switching_threshold_ratio_dict['sd15_1024'][self.switching_threshold_ratio]
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sd15_2048'][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    if self.info['text_to_img_controlnet']:
                        self.T1_ratio = text_to_img_controlnet_switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio]
                    else:
                        self.T1_ratio = switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio]
                    if self.info['is_inpainting_task']:
                        self.aggressive_raunet = inpainting_is_aggressive_raunet
                    else:
                        self.aggressive_raunet = is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                self.T1_start = int(aggressive_step/50 * self.max_timestep)
                self.T1_end = int(self.max_timestep * self.T1_ratio)
                self.T1 = 0 # to avoid confict with sdxl-turbo
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)

            # pooling is active inside the aggressive window, otherwise during the first T1 calls
            in_aggressive_window = self.aggressive_raunet and self.T1_start <= self.timestep < self.T1_end
            pool_first_pair = in_aggressive_window or self.timestep < self.T1

            output_states = ()
            blocks = list(zip(self.resnets, self.attentions))
            last_index = len(blocks) - 1
            for i, (resnet, attn) in enumerate(blocks):
                if self.training and self.gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, use_reentrant=False)
                else:
                    hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                # apply additional residuals to the output of the last pair of resnet and attention blocks
                if i == last_index and additional_residuals is not None:
                    hidden_states = hidden_states + additional_residuals
                if i == 0 and pool_first_pair:
                    # remember the pre-pooling size so the matching up block can restore it
                    self.info["upsample_size"] = (hidden_states.shape[2], hidden_states.shape[3])
                    hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2,2),ceil_mode=True)
                output_states = output_states + (hidden_states,)

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            # advance the per-module call counter and wrap around at the end of a run
            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0
            return hidden_states, output_states

    return cross_attn_down_block
scale_factor=rescale, mode='bicubic') return first self.max_timestep = self.info['pipeline']._num_timesteps # pylint: disable=protected-access try: ori_H, ori_W = self.info['size'] except Exception as e: raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e if self.model == 'sd15': if ori_H < 256 or ori_W < 256: self.T1_ratio = switching_threshold_ratio_dict['sd15_1024'][self.switching_threshold_ratio] else: self.T1_ratio = switching_threshold_ratio_dict['sd15_2048'][self.switching_threshold_ratio] elif self.model == 'sdxl': if ori_H < 512 or ori_W < 512: if self.info['text_to_img_controlnet']: self.T1_ratio = text_to_img_controlnet_switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio] else: self.T1_ratio = switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio] if self.info['is_inpainting_task']: self.aggressive_raunet = inpainting_is_aggressive_raunet else: self.aggressive_raunet = is_aggressive_raunet else: self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio] elif self.model == 'sdxl_turbo': self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio] else: raise RuntimeError('HiDiffusion: unsupported model type') if self.aggressive_raunet: self.T1_start = int(aggressive_step/50 * self.max_timestep) self.T1_end = int(self.max_timestep * self.T1_ratio) self.T1 = 0 # to avoid confict with sdxl-turbo else: self.T1 = int(self.max_timestep * self.T1_ratio) for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = fix_scale(hidden_states, res_hidden_states) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, 
def make_diffusers_downsampler_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Build a subclass of `block_class` (a conv-style downsampler exposing `weight`,
    `bias`, `stride`, `padding`, `dilation` and `groups`) that uses a stride-4,
    dilation-2 convolution during the early denoising steps.

    The returned class keeps `block_class` in `_parent` so the patch can be undone.
    Instances are expected to carry `info`, `model` and `switching_threshold_ratio`
    attributes, which are assigned by `apply_hidiffusion`.
    """
    class downsampler_block(block_class):
        # Save for unpatching later
        _parent = block_class
        T1_ratio = 0
        T1 = 0
        timestep = 0 # number of forward calls seen in the current run
        aggressive_raunet = False
        max_timestep = 50

        def forward(self, hidden_states: torch.Tensor, scale = 1.0) -> torch.Tensor: # pylint: disable=unused-argument
            """
            Convolve `hidden_states` with the module weights; `scale` is accepted but ignored.

            Raises RuntimeError when `info['size']` cannot be unpacked or when
            `self.model` is not one of 'sd15', 'sdxl', 'sdxl_turbo'.
            """
            self.max_timestep = self.info['pipeline']._num_timesteps # pylint: disable=protected-access
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # pick the switching threshold for this model family and input size
            if self.model == 'sd15':
                if ori_H < 256 or ori_W < 256:
                    self.T1_ratio = switching_threshold_ratio_dict['sd15_1024'][self.switching_threshold_ratio]
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sd15_2048'][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    if self.info['text_to_img_controlnet']:
                        self.T1_ratio = text_to_img_controlnet_switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio]
                    else:
                        self.T1_ratio = switching_threshold_ratio_dict['sdxl_2048'][self.switching_threshold_ratio]
                    if self.info['is_inpainting_task']:
                        self.aggressive_raunet = inpainting_is_aggressive_raunet
                    else:
                        self.aggressive_raunet = is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                self.T1 = int(aggressive_step/50 * self.max_timestep)
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)

            # Choose the convolution geometry locally instead of overwriting and restoring
            # self.stride/padding/dilation: if the convolution raises, the module must not
            # be left stuck with the temporary stride-4 configuration.
            if self.timestep < self.T1:
                stride, padding, dilation = (4, 4), (2, 2), (2, 2)
            else:
                stride, padding, dilation = self.stride, self.padding, self.dilation
            hidden_states = F.conv2d(hidden_states, self.weight, self.bias, stride, padding, dilation, self.groups)

            # advance the per-module call counter and wrap around at the end of a run
            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0
            return hidden_states

    return downsampler_block
def hook_diffusion_model(model: torch.nn.Module):
    """
    Adds a forward pre hook to get the image size. This hook can be removed with remove_hidiffusion.

    The hook stores dims 2 and 3 of the first positional forward argument in
    `model.info["size"]`; the hook handle is appended to `model.info["hooks"]`.
    """
    def hook(module, args):
        module.info["size"] = (args[0].shape[2], args[0].shape[3])
        return None # returning None leaves the forward arguments untouched

    model.info["hooks"].append(model.register_forward_pre_hook(hook))


def apply_hidiffusion(
        model: torch.nn.Module,
        apply_raunet: bool = True,
        apply_window_attn: bool = True,
        model_type: str = 'None',
        steps: int = 50):
    """
    Patch a diffusers model in place with HiDiffusion blocks.

    model: diffusers model. We support SD 1.5, 2.1, XL, XL Turbo.
    apply_raunet: whether to apply RAU-Net
    apply_window_attn: whether to apply MSW-MSA.
    model_type: 'sd' or 'sdxl'; selects which module keys are patched.
    steps: stored in the module-level `current_steps` before any block class is created.

    Raises RuntimeError for any other `model_type`, or when a module of the
    diffusion model is already patched. Both checks run before anything is
    modified, so a failed call leaves the model untouched.
    """
    global current_steps # pylint: disable=global-statement
    if model_type not in ('sd', 'sdxl'):
        raise RuntimeError('HiDiffusion: unsupported model type')
    diffusion_model = model.unet if hasattr(model, "unet") else model
    # Refuse double patching up front. Doing this inside the patch loop would trip on the
    # root module right after the controlnet branch below has given it a `_parent`.
    for key, module in diffusion_model.named_modules():
        if hasattr(module, "_parent"):
            raise RuntimeError(f'HiDiffusion: key={key} module={module.__class__} already patched')

    current_steps = steps
    if hasattr(model, 'controlnet'):
        from .hidiffusion_controlnet import make_diffusers_sdxl_contrtolnet_ppl, make_diffusers_unet_2d_condition
        model.__class__ = make_diffusers_sdxl_contrtolnet_ppl(model.__class__)
        model.unet.__class__ = make_diffusers_unet_2d_condition(model.unet.__class__)

    diffusion_model.num_upsamplers += 12
    diffusion_model.info = {
        'size': None,
        'upsample_size': None,
        'hooks': [],
        'text_to_img_controlnet': hasattr(model, 'controlnet'),
        'is_inpainting_task': model.__class__ in auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING.values(),
        'pipeline': model}

    # (key group, block factory, switching-threshold name) for each RAU-Net patch
    if model_type == 'sd':
        modified_key = sd15_hidiffusion_key()
        model_name = 'sd15'
        raunet_plan = (
            ('down_module_key', make_diffusers_downsampler_block, 'T1_ratio'),
            ('down_module_key_extra', make_diffusers_cross_attn_down_block, 'T2_ratio'),
            ('up_module_key', make_diffusers_upsampler_block, 'T1_ratio'),
            ('up_module_key_extra', make_diffusers_cross_attn_up_block, 'T2_ratio'),
        )
    else:
        modified_key = sdxl_hidiffusion_key()
        model_name = 'sdxl'
        raunet_plan = (
            ('down_module_key', make_diffusers_cross_attn_down_block, 'T1_ratio'),
            ('down_module_key_extra', make_diffusers_downsampler_block, 'T2_ratio'),
            ('up_module_key', make_diffusers_cross_attn_up_block, 'T1_ratio'),
            ('up_module_key_extra', make_diffusers_upsampler_block, 'T2_ratio'),
        )

    for key, module in diffusion_model.named_modules():
        if apply_raunet:
            for key_group, make_block, ratio_name in raunet_plan:
                if key in modified_key[key_group]:
                    module.__class__ = make_block(module.__class__)
                    module.switching_threshold_ratio = ratio_name
        if apply_window_attn and key in modified_key['windown_attn_module_key']:
            module.__class__ = make_diffusers_transformer_block(module.__class__)
        if hasattr(module, "_parent"):
            module.model = model_name
            module.info = diffusion_model.info
            # The call counter lives on the instance and survives unpatching, so an
            # interrupted run would otherwise leak a stale value into this one.
            if hasattr(module, "timestep"):
                module.timestep = 0

    model.info = diffusion_model.info
    model.hidiffusion = True
    hook_diffusion_model(diffusion_model)


def remove_hidiffusion(model: torch.nn.Module):
    """
    Removes hidiffusion from a Diffusion module if it was already patched.

    Every module whose class carries `_parent` is switched back to that parent
    class, registered hooks are removed and the `info` attribute is deleted from
    the modules. When `model` wraps a `unet`, the wrapper's own class is reverted
    as well, since `apply_hidiffusion` may have replaced it.
    """
    pipeline = model
    model = model.unet if hasattr(model, "unet") else model
    for _, module in model.named_modules():
        while hasattr(module, "_parent"):
            # NOTE(review): the flag is set to True on removal, as before; confirm what reads it
            model.hidiffusion = True
            module.__class__ = module._parent # pylint: disable=protected-access
        if hasattr(module, "info"):
            hooks = module.info.get("hooks", [])
            for hook in hooks:
                hook.remove()
            hooks.clear()
            del module.info
    if pipeline is not model:
        while hasattr(pipeline, "_parent"):
            pipeline.__class__ = pipeline._parent # pylint: disable=protected-access
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: default_overall_up_factor = 2**self.num_upsamplers forward_upsample_size = False upsample_size = None for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: forward_upsample_size = True break if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if self.config.center_input_sample: sample = 2 * sample - 1.0 timesteps = timestep if not torch.is_tensor(timesteps): is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": 
aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": if "image_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`") image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`") text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`") time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": if "image_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`") image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword 
arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`") image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`") image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`") image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError(f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`") image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) sample = 
self.conv_in(sample) if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: # sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): _, _, ori_H, ori_W = 
down_block_res_sample.shape down_block_additional_residual = F.interpolate(down_block_additional_residual, (ori_H, ori_W), mode='bicubic') down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: _, _, ori_H, ori_W = sample.shape mid_block_additional_residual = F.interpolate(mid_block_additional_residual, (ori_H, ori_W), mode='bicubic') sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, # scale=lora_scale, ) # sample = upsample_block( # hidden_states=sample, # temb=emb, # res_hidden_states_tuple=res_samples, # upsample_size=upsample_size, # scale=lora_scale, # ) # 6. 
def make_diffusers_sdxl_contrtolnet_ppl(block_class):
    """Create a patched SDXL ControlNet pipeline class derived from ``block_class``.

    The returned subclass overrides ``__call__`` so one entry point serves both
    modes: when ``control_image`` is omitted, ``image`` is treated as the control
    image (text-to-image ControlNet); when both are given the call runs
    image-to-image ControlNet. ControlNet itself is evaluated on a downsampled
    copy of the latents and control image, and only while
    ``i < controlnet_apply_steps_rate * num_inference_steps``
    (``controlnet_apply_steps_rate`` is a module-level name defined outside
    this block).
    """

    class sdxl_contrtolnet_ppl(block_class):
        # Save for unpatching later
        _parent = block_class

        @torch.no_grad()
        def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            prompt_2: Optional[Union[str, List[str]]] = None,
            image: PipelineImageInput = None,
            control_image: PipelineImageInput = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            strength: float = 0.8,
            num_inference_steps: int = 50,
            guidance_scale: float = 5.0,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            negative_prompt_2: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
            guess_mode: bool = False,
            control_guidance_start: Union[float, List[float]] = 0.0,
            control_guidance_end: Union[float, List[float]] = 1.0,
            original_size: Tuple[int, int] = None,
            crops_coords_top_left: Tuple[int, int] = (0, 0),
            target_size: Tuple[int, int] = None,
            negative_original_size: Optional[Tuple[int, int]] = None,
            negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
            negative_target_size: Optional[Tuple[int, int]] = None,
            aesthetic_score: float = 6.0,
            negative_aesthetic_score: float = 2.5,
            clip_skip: Optional[int] = None,
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            **kwargs,
        ):
            """Run ControlNet-guided SDXL generation.

            Returns a ``StableDiffusionXLPipelineOutput`` (or a one-element tuple
            when ``return_dict`` is False). With ``output_type == "latent"`` the
            raw latents are returned wrapped in the output object.
            """
            # convert image to control_image to fit sdxl_controlnet ppl.
            if control_image is None:
                control_image = image
                image = None
                self.info['text_to_img_controlnet'] = True
            else:
                self.info['text_to_img_controlnet'] = False
            is_img2img = image is not None

            callback = kwargs.pop("callback", None)
            callback_steps = kwargs.pop("callback_steps", None)

            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

            # align format for control guidance
            if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
                control_guidance_start = len(control_guidance_end) * [control_guidance_start]
            elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
                control_guidance_end = len(control_guidance_start) * [control_guidance_end]
            elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
                mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
                control_guidance_start, control_guidance_end = (
                    mult * [control_guidance_start],
                    mult * [control_guidance_end],
                )

            # 1. Check inputs. Raise error if not correct
            # The two parent pipelines take different positional layouts, so the
            # argument order below is intentionally not shared.
            if is_img2img:
                # image-to-image controlnet
                self.check_inputs(
                    prompt,
                    prompt_2,
                    control_image,
                    strength,
                    num_inference_steps,
                    callback_steps,
                    negative_prompt,
                    negative_prompt_2,
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                    None,
                    None,
                    controlnet_conditioning_scale,
                    control_guidance_start,
                    control_guidance_end,
                    callback_on_step_end_tensor_inputs,
                )
            else:
                # text-to-image controlnet
                self.check_inputs(
                    prompt,
                    prompt_2,
                    control_image,
                    callback_steps,
                    negative_prompt,
                    negative_prompt_2,
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    None,
                    None,
                    negative_pooled_prompt_embeds,
                    controlnet_conditioning_scale,
                    control_guidance_start,
                    control_guidance_end,
                    callback_on_step_end_tensor_inputs,
                )

            self._guidance_scale = guidance_scale
            self._clip_skip = clip_skip
            self._cross_attention_kwargs = cross_attention_kwargs

            # 2. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = self._execution_device

            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

            global_pool_conditions = (
                controlnet.config.global_pool_conditions
                if isinstance(controlnet, ControlNetModel)
                else controlnet.nets[0].config.global_pool_conditions
            )
            guess_mode = guess_mode or global_pool_conditions

            # 3. Encode input prompt
            text_encoder_lora_scale = (
                self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
            )
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.encode_prompt(
                prompt,
                prompt_2,
                device,
                num_images_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                negative_prompt_2,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )

            # 4. Prepare image and controlnet_conditioning_image
            if is_img2img:
                image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
                prepare_control = self.prepare_control_image
            else:
                prepare_control = self.prepare_image
            prepare_kwargs = {
                "width": width,
                "height": height,
                "batch_size": batch_size * num_images_per_prompt,
                "num_images_per_prompt": num_images_per_prompt,
                "device": device,
                "dtype": controlnet.dtype,
                "do_classifier_free_guidance": self.do_classifier_free_guidance,
                "guess_mode": guess_mode,
            }
            if isinstance(controlnet, ControlNetModel):
                control_image = prepare_control(image=control_image, **prepare_kwargs)
                height, width = control_image.shape[-2:]
            elif isinstance(controlnet, MultiControlNetModel):
                control_image = [prepare_control(image=cond, **prepare_kwargs) for cond in control_image]
                # Size comes from the prepared control images; in text-to-image mode
                # `image` is None, so it must not be indexed here.
                height, width = control_image[0].shape[-2:]
            else:
                raise AssertionError

            # 5. Prepare timesteps
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            if is_img2img:
                timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
                latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
            else:
                timesteps = self.scheduler.timesteps
            self._num_timesteps = len(timesteps)

            # 6. Prepare latent variables
            if is_img2img:
                # image-to-image controlnet
                latents = self.prepare_latents(
                    image,
                    latent_timestep,
                    batch_size,
                    num_images_per_prompt,
                    prompt_embeds.dtype,
                    device,
                    generator,
                    True,
                )
            else:
                # text-to-image controlnet
                num_channels_latents = self.unet.config.in_channels
                latents = self.prepare_latents(
                    batch_size * num_images_per_prompt,
                    num_channels_latents,
                    height,
                    width,
                    prompt_embeds.dtype,
                    device,
                    generator,
                    latents,
                )

            # 7. Prepare extra step kwargs.
            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

            # 7.1 Create tensor stating which controlnets to keep
            controlnet_keep = []
            for i in range(len(timesteps)):
                keeps = [
                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                    for s, e in zip(control_guidance_start, control_guidance_end)
                ]
                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

            # 7.2 Prepare added time ids & embeddings
            if isinstance(control_image, list):
                original_size = original_size or control_image[0].shape[-2:]
            else:
                original_size = original_size or control_image.shape[-2:]
            target_size = target_size or (height, width)

            if is_img2img:
                if negative_original_size is None:
                    negative_original_size = original_size
                if negative_target_size is None:
                    negative_target_size = target_size

            add_text_embeds = pooled_prompt_embeds
            if self.text_encoder_2 is None:
                text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
            else:
                text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

            if is_img2img:
                add_time_ids, add_neg_time_ids = self._get_add_time_ids(
                    original_size,
                    crops_coords_top_left,
                    target_size,
                    aesthetic_score,
                    negative_aesthetic_score,
                    negative_original_size,
                    negative_crops_coords_top_left,
                    negative_target_size,
                    dtype=prompt_embeds.dtype,
                    text_encoder_projection_dim=text_encoder_projection_dim,
                )
                add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

                if self.do_classifier_free_guidance:
                    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
                    add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
                    add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

                prompt_embeds = prompt_embeds.to(device)
                add_text_embeds = add_text_embeds.to(device)
                add_time_ids = add_time_ids.to(device)
            else:
                add_time_ids = self._get_add_time_ids(
                    original_size,
                    crops_coords_top_left,
                    target_size,
                    dtype=prompt_embeds.dtype,
                    text_encoder_projection_dim=text_encoder_projection_dim,
                )

                if negative_original_size is not None and negative_target_size is not None:
                    negative_add_time_ids = self._get_add_time_ids(
                        negative_original_size,
                        negative_crops_coords_top_left,
                        negative_target_size,
                        dtype=prompt_embeds.dtype,
                        text_encoder_projection_dim=text_encoder_projection_dim,
                    )
                else:
                    negative_add_time_ids = add_time_ids

                if self.do_classifier_free_guidance:
                    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
                    add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

                prompt_embeds = prompt_embeds.to(device)
                add_text_embeds = add_text_embeds.to(device)
                add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

            # 8. Denoising loop
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                    # controlnet(s) inference
                    if guess_mode and self.do_classifier_free_guidance:
                        # Infer ControlNet only for the conditional batch.
                        control_model_input = latents
                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                        controlnet_added_cond_kwargs = {
                            "text_embeds": add_text_embeds.chunk(2)[1],
                            "time_ids": add_time_ids.chunk(2)[1],
                        }
                    else:
                        control_model_input = latent_model_input
                        controlnet_prompt_embeds = prompt_embeds
                        controlnet_added_cond_kwargs = added_cond_kwargs

                    if isinstance(controlnet_keep[i], list):
                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                    else:
                        controlnet_cond_scale = controlnet_conditioning_scale
                        if isinstance(controlnet_cond_scale, list):
                            controlnet_cond_scale = controlnet_cond_scale[0]
                        cond_scale = controlnet_cond_scale * controlnet_keep[i]

                    # ControlNet only contributes during the first part of the schedule;
                    # afterwards the UNet runs without additional residuals.
                    down_block_res_samples = None
                    mid_block_res_sample = None
                    if i < controlnet_apply_steps_rate * num_inference_steps:
                        # Run ControlNet at a reduced resolution: fit the latent into a
                        # 128x128 box and snap both sides down to a multiple of 8.
                        original_h, original_w = (128, 128)
                        _, _, model_input_h, model_input_w = control_model_input.shape
                        downsample_factor = max(model_input_h / original_h, model_input_w / original_w)
                        downsample_size = (
                            int(model_input_h // downsample_factor) // 8 * 8,
                            int(model_input_w // downsample_factor) // 8 * 8,
                        )
                        # Pixel-space control image is 8x the latent size.
                        downsample_pixel_size = [downsample_size[0] * 8, downsample_size[1] * 8]

                        # Multi-ControlNet keeps one control image per net; resize each.
                        if isinstance(control_image, list):
                            controlnet_cond = [F.interpolate(cond, downsample_pixel_size) for cond in control_image]
                        else:
                            controlnet_cond = F.interpolate(control_image, downsample_pixel_size)

                        down_block_res_samples, mid_block_res_sample = self.controlnet(
                            F.interpolate(control_model_input, downsample_size),
                            t,
                            encoder_hidden_states=controlnet_prompt_embeds,
                            controlnet_cond=controlnet_cond,
                            conditioning_scale=cond_scale,
                            guess_mode=guess_mode,
                            added_cond_kwargs=controlnet_added_cond_kwargs,
                            return_dict=False,
                        )

                        if guess_mode and self.do_classifier_free_guidance:
                            # Inferred ControlNet only for the conditional batch.
                            # To apply the output of ControlNet to both the unconditional and conditional batches,
                            # add 0 to the unconditional batch to keep it unchanged.
                            down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                            mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        # Tensor inputs are looked up by local variable name.
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(self.scheduler, "order", 1)
                            callback(step_idx, t, latents)

            # If we do sequential model offloading, let's offload unet and controlnet
            # manually for max memory savings
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.unet.to("cpu")
                self.controlnet.to("cpu")
                torch.cuda.empty_cache()

            if output_type == "latent":
                image = latents
                return StableDiffusionXLPipelineOutput(images=image)

            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

            # Offload all models
            self.maybe_free_model_hooks()

            if not return_dict:
                return (image,)

            return StableDiffusionXLPipelineOutput(images=image)

    return sdxl_contrtolnet_ppl
def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """
    return any(_cls.__name__ == cls_name for _cls in x.__class__.__mro__)


def init_generator(device: torch.device, fallback: torch.Generator = None):
    """
    Forks the current default random generator given device.

    For CPU and CUDA devices a new generator is created on that device and
    seeded with the device's current global RNG state. For any other device
    type, ``fallback`` is returned when given; otherwise a CPU fork is used.
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    if device.type == "cuda":
        # Read the state of the requested device, not whichever CUDA device
        # happens to be current, so multi-GPU setups fork the right stream.
        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state(device))
    if fallback is None:
        return init_generator(torch.device("cpu"))
    return fallback
def add(self, latent, preview=None, info=None, ops=None):
    """History.add: record ``latent`` as the newest history entry.

    Always increments ``shared.state.latent_history``. Nothing is stored when
    ``shared.opts.latent_history`` is 0 or when ``latent`` is not a tensor.
    After inserting, the oldest entries are dropped until at most
    ``shared.opts.latent_history`` items remain.
    """
    shared.state.latent_history += 1
    limit = shared.opts.latent_history
    if limit == 0:
        return
    if torch.is_tensor(latent):
        # Item copies `ops`, so hand it a list even when the caller passed nothing.
        item = Item(latent, preview, info, [] if ops is None else ops)
        self.latents.appendleft(item)
        # Trim only when the limit is actually exceeded; trimming at `>=` would
        # cap the history at limit-1 entries (and at zero entries for limit=1).
        while self.latents and self.count > limit:
            self.latents.pop()
def sanitize_filename_part(text, replace_spaces=True):
    """Make ``text`` safe to use as one component of a file name.

    Optionally replaces spaces with underscores, maps characters that are
    invalid in file names to ``_``, strips leading spaces, truncates to 64
    characters and finally strips trailing spaces and dots.
    Returns ``None`` when ``text`` is ``None``.
    """
    if text is None:
        return None
    if replace_spaces:
        text = text.replace(' ', '_')
    invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
    invalid_filename_prefix = ' '
    invalid_filename_postfix = ' .'
    max_filename_part_length = 64
    text = text.translate({ord(x): '_' for x in invalid_filename_chars})
    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
    text = text.rstrip(invalid_filename_postfix)
    return text


def _exif_bytes(exifinfo):
    """Serialize ``exifinfo`` into an EXIF blob that stores the text as UserComment."""
    return piexif.dump({
        "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(exifinfo, encoding="unicode")}
    })


def _build_save_args(image, image_format, params, exifinfo):
    """Return ``(image, save_args)`` for ``image_format``, converting unsupported image modes."""
    def scale_16bit(p):
        return p * 0.0038910505836576  # maps 16-bit samples to the 8-bit range

    if image_format == 'PNG':
        pnginfo_data = PngImagePlugin.PngInfo()
        for k, v in params.pnginfo.items():
            pnginfo_data.add_text(k, str(v))
        debug_save(f'Save pnginfo: {params.pnginfo.items()}')
        return image, {'compress_level': 6, 'pnginfo': pnginfo_data if shared.opts.image_metadata else None}

    if image_format == 'JPEG':
        if image.mode == 'RGBA':
            shared.log.warning('Save: removing alpha channel')
            image = image.convert("RGB")
        elif image.mode == 'I;16':
            image = image.point(scale_16bit).convert("L")
        save_args = {'optimize': True, 'quality': shared.opts.jpeg_quality}
    elif image_format == 'WEBP':
        if image.mode == 'I;16':
            image = image.point(scale_16bit).convert("RGB")
        save_args = {'optimize': True, 'quality': shared.opts.jpeg_quality, 'lossless': shared.opts.webp_lossless}
    elif image_format == 'JXL':
        if image.mode == 'I;16':
            image = image.point(scale_16bit).convert("RGB")
        elif image.mode not in {"RGB", "RGBA"}:
            image = image.convert("RGBA")
        save_args = {'optimize': True, 'quality': shared.opts.jpeg_quality, 'lossless': shared.opts.webp_lossless}
    else:
        return image, {'quality': shared.opts.jpeg_quality}

    if shared.opts.image_metadata:
        debug_save(f'Save exif: {exifinfo}')
        save_args['exif'] = _exif_bytes(exifinfo)
    return image, save_args


def _save_queued_image(image, filename, extension, params, exifinfo, filename_txt, is_grid):
    """Write one queued image plus its optional text sidecar and JSON log entry."""
    shared.state.image_history += 1
    if exifinfo and len(exifinfo) > 2:
        with open(paths.params_path, "w", encoding="utf8") as file:
            file.write(exifinfo)
    # Normalize the extension before building the target path so the path and
    # the format lookup agree; an empty extension falls through to the JPEG default.
    if extension and not extension.startswith('.'):
        extension = '.' + extension
    fn = filename + extension
    filename = filename.strip()
    image_format = Image.registered_extensions().get(extension)
    if image_format is None:
        shared.log.warning(f'Save: unknown image format: {extension}')
        image_format = 'JPEG'
    exifinfo = (exifinfo or "") if shared.opts.image_metadata else ""

    # additional metadata saved in files
    if shared.opts.save_txt and len(exifinfo) > 0:
        try:
            with open(filename_txt, "w", encoding="utf8") as file:
                file.write(f"{exifinfo}\n")
            shared.log.info(f'Save: text="{filename_txt}" len={len(exifinfo)}')
        except Exception as e:
            shared.log.warning(f'Save failed: description={filename_txt} {e}')

    # actual save
    image, save_args = _build_save_args(image, image_format, params, exifinfo)
    try:
        debug_save(f'Save args: {save_args}')
        image.save(fn, format=image_format, **save_args)
    except Exception as e:
        shared.log.error(f'Save failed: file="{fn}" format={image_format} args={save_args} {e}')
        errors.display(e, 'Image save')
    size = os.path.getsize(fn) if os.path.exists(fn) else 0
    what = 'grid' if is_grid else 'image'
    shared.log.info(f'Save: {what}="{fn}" type={image_format} width={image.width} height={image.height} size={size}')

    if shared.opts.save_log_fn != '' and len(exifinfo) > 0:
        fn = os.path.join(paths.data_path, shared.opts.save_log_fn)
        if not fn.endswith('.json'):
            fn += '.json'
        entries = shared.readfile(fn, silent=True)
        if not isinstance(entries, list):
            entries = []
        idx = len(entries)
        entry = {'id': idx, 'filename': filename, 'time': datetime.datetime.now().isoformat(), 'info': exifinfo}
        entries.append(entry)
        shared.writefile(entries, fn, mode='w', silent=True)
        shared.log.info(f'Save: json="{fn}" records={len(entries)}')
    shared.state.outputs(filename)


def atomically_save_image():
    """Worker loop: take save requests from ``save_queue`` and write them to disk, forever.

    Each request is marked done even when saving raises, so producers blocked
    in ``save_queue.join()`` are always released and the worker keeps running.
    """
    Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes
    while True:
        item = save_queue.get()
        jobid = None
        try:
            jobid = shared.state.begin('Save image')
            _save_queued_image(*item)
        except Exception as e:
            # An unexpected failure must not kill the worker thread: log it and
            # continue with the next request.
            shared.log.error(f'Save failed: {e}')
            errors.display(e, 'Image save')
        finally:
            if jobid is not None:
                shared.state.end(jobid)
            save_queue.task_done()
disable=protected-access debug_save(f'Save: fn={fn}') # pylint: disable=protected-access if image is None: shared.log.warning('Image is none') return None, None, None if isinstance(image, list): if len(image) > 1: shared.log.warning(f'Save: images={image} multiple images provided only the first one will be saved') image = image[0] if not check_grid_size([image]): return None, None, None if path is None or path == '': # set default path to avoid errors when functions are triggered manually or via api and param is not set path = paths.resolve_output_path(shared.opts.outdir_samples, shared.opts.outdir_save) namegen = FilenameGenerator(p, seed, prompt, image, grid=grid) suffix = suffix if suffix is not None else '' basename = '' if basename is None else basename if save_to_dirs is not None and isinstance(save_to_dirs, str) and len(save_to_dirs) > 0: dirname = save_to_dirs path = os.path.join(path, dirname) elif shared.opts.save_to_dirs: dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]") path = os.path.join(path, dirname) if forced_filename is None: if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0: file_decoration = shared.opts.samples_filename_pattern else: file_decoration = "[seq]-[prompt_words]" file_decoration = namegen.apply(file_decoration) file_decoration += suffix if file_decoration.startswith(basename): basename = '' filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}") else: forced_filename += suffix if forced_filename.startswith(basename): basename = '' filename = os.path.join(path, f"{forced_filename}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{forced_filename}.{extension}") pnginfo = existing_info or {} if info is None: info = image.info.get(pnginfo_section_name, '') if info is not None: pnginfo[pnginfo_section_name] = info wm_text = getattr(p, 'watermark_text', 
def safe_decode_string(s: bytes):
    """Decode an EXIF-style byte string by trying a fixed list of encodings in order.

    Before every attempt the markers ``UNICODE``, ``ASCII`` and a single NUL
    byte are stripped from the front (the stripping accumulates across
    attempts). The first encoding that decodes without error wins; control
    characters 0x00-0x09 are removed and the text is trimmed. Returns ``None``
    when the winning text is empty or when no encoding succeeds.
    """
    def strip_prefix(data, prefix):
        return data[len(prefix):] if data.startswith(prefix) else data

    candidates = ('utf_16_be', 'utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437')
    for codec in candidates:  # try different encodings
        try:
            for marker in (b'UNICODE', b'ASCII', b'\x00'):
                s = strip_prefix(s, marker)
            decoded = s.decode(codec, errors="strict")
            decoded = re.sub(r'[\x00-\x09]', '', decoded).strip()  # remove remaining special characters
            return decoded if len(decoded) > 0 else None  # empty result is reported as None
        except Exception:
            continue
    return None
def parse_invoke_metadata(data: dict):
    """Summarize InvokeAI metadata found under ``data['invokeai_metadata']``.

    Returns ``'App: InvokeAI | Version: <v>'`` (and logs it) when the JSON
    payload carries a non-empty string ``app_version``; otherwise ``''``.
    """
    def read_version():
        try:
            meta = json.loads(data.get('invokeai_metadata', {}))
            if 'app_version' in meta:
                version = meta['app_version']
                if isinstance(version, str) and len(version) > 0:
                    return f" | Version: {version}"
        except Exception:
            pass
        return ''

    metadata = read_version()
    if len(metadata) == 0:
        return ''
    parsed = f'App: InvokeAI{metadata}'
    shared.log.info(f'Image metadata: {parsed}')
    return parsed


def parse_novelai_metadata(data: dict):
    """Build a generation-info string from NovelAI image metadata.

    Only acts when ``data['Software'] == 'NovelAI'``; the JSON in
    ``data['Comment']`` supplies prompt settings and the sampler name is
    mapped through ``sd_samplers.samplers_map`` with ``"Euler a"`` as default.
    Any failure while reading the fields yields ``''``.
    """
    if data.get("Software", None) != "NovelAI":
        return ''
    try:
        dct = json.loads(data["Comment"])
        sampler = sd_samplers.samplers_map.get(dct["sampler"], "Euler a")
        return f'{data["Description"]} Negative prompt: {dct["uc"]} Steps: {dct["steps"]}, Sampler: {sampler}, CFG scale: {dct["scale"]}, Seed: {dct["seed"]}, Clip skip: 2, ENSD: 31337'
    except Exception:
        return ''
def image_data(data):
    """Extract metadata text from a raw payload.

    The payload is first opened as an image and its generation info is read;
    if that fails it is decoded as UTF-8 text, unless it is longer than
    10 KiB. Returns ``(info, None)`` on success and ``(gr.update(), None)``
    when the payload is ``None`` or cannot be decoded.
    """
    import gradio as gr

    if data is None:
        return gr.update(), None

    err1 = None
    err2 = None

    # attempt 1: payload is an encoded image
    try:
        picture = Image.open(io.BytesIO(data))
        picture.load()
        metadata, _ = read_info_from_image(picture)
        errors.log.debug(f'Decoded object: image={picture} metadata={metadata}')
        return metadata, None
    except Exception as exc:
        err1 = exc

    # attempt 2: payload is plain text
    try:
        size = len(data)
        if size > 1024 * 10:
            errors.log.warning(f'Error decoding object: data too long: {size}')
            return gr.update(), None
        metadata = data.decode('utf8')
        errors.log.debug(f'Decoded object: data={len(data)} metadata={metadata}')
        return metadata, None
    except Exception as exc:
        err2 = exc

    errors.log.error(f'Error decoding object: {err1 or err2}')
    return gr.update(), None


def flatten(img, bgcolor):
    """Return an RGB version of ``img``.

    An image in ``RGBA`` mode is first pasted onto a canvas filled with
    ``bgcolor`` (for example ``"#ffffff"``), using the image itself as the
    paste mask; any other mode is converted directly.
    """
    if img.mode != "RGBA":
        return img.convert('RGB')
    canvas = Image.new('RGBA', img.size, bgcolor)
    canvas.paste(img, mask=img)
    return canvas.convert('RGB')


def draw_overlay(im, text: str = '', y_offset: int = 0):
    """Draw ``text`` near the top-left corner of ``im`` in place and return ``im``.

    The font size is derived from the image dimensions and ``y_offset`` shifts
    the text vertically.
    """
    painter = ImageDraw.Draw(im)
    size = (im.width + im.height) // 50
    anchor = (size // 2, size // 2 + y_offset)
    painter.text(anchor, text, font=get_font(size), fill=shared.opts.font_color)
    return im
def get_watermark(image):
    """Decode the invisible watermark from ``image``.

    Uses the ``dwtDctSvd`` method with a 32-bit ``bytes`` payload and returns
    the payload decoded as ASCII (undecodable bytes are dropped), or an empty
    string when decoding fails.
    """
    from imwatermark import WatermarkDecoder

    method = 'dwtDctSvd'
    pixels = np.asarray(image)
    reader = WatermarkDecoder('bytes', 32)
    try:
        payload = reader.decode(pixels, method)
        return payload.decode(encoding='ascii', errors='ignore')
    except Exception:
        return ''
def check_grid_size(imgs):
    """Check that the combined pixel count of ``imgs`` is within the configured limit.

    Args:
        imgs: images, or lists of images; ``None`` entries count as zero pixels.

    Returns:
        ``False`` for ``None``/empty input or when the total, rounded to whole
        megapixels, exceeds ``shared.opts.img_max_size_mp`` (a warning is
        logged in that case); ``True`` otherwise.
    """
    if imgs is None or len(imgs) == 0:
        return False
    mp = 0
    for img in imgs:
        if isinstance(img, list):
            for im in img:
                mp += im.width * im.height if im is not None else 0
        else:
            mp += img.width * img.height if img is not None else 0
    mp = round(mp / 1000000)
    ok = mp <= shared.opts.img_max_size_mp
    if not ok:
        shared.log.warning(f'Maximum image size exceded: size={mp} maximum={shared.opts.img_max_size_mp} MPixels')
    return ok


def get_grid_size(imgs, batch_size=1, rows=None, cols=None):
    """Work out the ``(rows, cols)`` layout for a grid of ``imgs``.

    Args:
        imgs: images to lay out; ``None`` is treated as an empty list.
        batch_size: used as the row/column count when the matching
            ``shared.opts.n_rows``/``n_cols`` option is 0.
        rows: requested row count, clamped to the number of images.
        cols: requested column count, clamped to the number of images.

    Returns:
        Tuple ``(rows, cols)``. When neither is requested the layout comes from
        ``shared.opts.n_rows``/``n_cols``, falling back to the largest divisor
        of the image count not exceeding its square root. A dimension that
        would require dividing by zero (no images, zero rows/cols) is 0
        instead of raising.
    """
    count = len(imgs) if imgs is not None else 0

    def ceil_div(total, parts):
        # a zero-sized dimension cannot hold anything: report 0 instead of dividing by zero
        return math.ceil(total / parts) if parts else 0

    if rows and rows > count:
        rows = count
    if cols and cols > count:
        cols = count
    if rows is None and cols is None:
        if shared.opts.n_rows > 0:
            rows = shared.opts.n_rows
            cols = ceil_div(count, rows)
        elif shared.opts.n_rows == 0:
            rows = batch_size
            cols = ceil_div(count, rows)
        elif shared.opts.n_cols > 0:
            cols = shared.opts.n_cols
            rows = ceil_div(count, cols)
        elif shared.opts.n_cols == 0:
            cols = batch_size
            rows = ceil_div(count, cols)
        else:
            rows = math.floor(math.sqrt(count))
            # walk down to the nearest divisor; the rows > 1 guard keeps count == 0 from dividing by zero
            while rows > 1 and count % rows != 0:
                rows -= 1
            cols = ceil_div(count, rows)
    elif cols is None:
        cols = ceil_div(count, rows)
    elif rows is None:
        rows = ceil_div(count, cols)
    return rows, cols


def image_grid(imgs, batch_size:int=1, rows:int=None, cols:int=None):
    """Paste ``imgs`` into a single RGB grid image.

    The layout is computed by ``get_grid_size`` and may be altered by the
    image-grid script callbacks, which receive the images, column count and
    row count. Each cell is as large as the widest/tallest image.

    Returns:
        The grid image, or ``None`` when there are no non-``None`` images.
    """
    rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)
    imgs = [i for i in imgs if i is not None] if imgs is not None else []
    if len(imgs) == 0:
        return None
    w, h = max(i.width for i in imgs), max(i.height for i in imgs)
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=shared.opts.grid_background)
    # iterate the callback-visible list so that any changes made by callbacks are honoured
    for i, img in enumerate(params.imgs):
        if img is not None:
            grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
    return grid
def combine_grid(grid):
    """Reassemble the tiles of ``grid`` into one RGB image, blending overlaps.

    Args:
        grid: object with ``tiles`` (entries of ``[y, tile_h, row]`` where
            ``row`` holds ``[x, tile_w, tile]``), ``tile_h``, ``image_w``,
            ``image_h`` and ``overlap``.

    Returns:
        A new ``image_w`` x ``image_h`` RGB image. The first ``overlap`` pixels
        of every tile that does not start at 0 are pasted through a linear
        ramp mask; the remainder is pasted as-is.
    """
    def make_mask_image(r):
        # scale the 0..overlap-1 ramp to 8-bit mask values
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    # horizontal ramp (left-to-right) for joining tiles within a row
    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    # vertical ramp (top-to-bottom) for joining rows
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue
        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
    return combined_image


class GridAnnotation:
    """One line of caption text for a grid row, column or title."""

    def __init__(self, text='', is_active=True):
        self.text = str(text)  # any value is accepted and stored as its string form
        self.is_active = is_active
        self.size = None  # filled in later by the grid annotation drawing code
        self.allowed_width = None  # likewise assigned during layout; declared here so the attribute always exists
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
    """Return a copy of grid image ``im`` with column, row and title captions.

    Args:
        im: grid image composed of ``width`` x ``height`` cells.
        width: cell width in pixels.
        height: cell height in pixels.
        x_texts: one list of annotation objects per column (drawn above the grid).
        y_texts: one list of annotation objects per row (drawn left of the grid).
        margin: gap in pixels inserted between cells.
        title: optional list of annotation objects drawn centered above everything.

    Returns:
        A new RGB image containing the re-spaced cells and the captions.

    Note:
        The lists inside ``x_texts``, ``y_texts`` and ``title`` are rewritten in
        place with word-wrapped annotations carrying ``size`` and
        ``allowed_width``.
    """
    def wrap(drawing, text, font, line_length):
        # greedy word wrap: keep appending words while the rendered line still fits
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        for line in lines:
            font = initial_fnt
            fontsize = initial_fontsize
            # shrink until the line fits; stop at 1 because a font cannot be loaded with size 0
            while drawing.multiline_textbbox((0,0), text=line.text, font=font)[2] > line.allowed_width and fontsize > 1:
                fontsize -= 1
                font = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=font, fill=shared.opts.font_color if line.is_active else color_inactive, anchor="mm", align="center")
            if not line.is_active:
                # strike through inactive entries
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
            draw_y += line.size[1] + line_spacing

    fontsize = max(1, (width + height) // 25)  # tiny cells would otherwise request a zero-size font
    line_spacing = fontsize // 2
    font = get_font(fontsize)
    color_inactive = (127, 127, 127)
    # reserve a left column only when at least one row caption has text
    pad_left = 0 if sum(len(line.text) for lines in y_texts for line in lines) == 0 else width * 3 // 4
    cols = len(x_texts)
    rows = len(y_texts)
    calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)  # scratch surface used only for text measurement
    calc_d = ImageDraw.Draw(calc_img)
    title_texts = [title] if title else [[GridAnnotation()]]
    for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]):
        items = [] + texts
        texts.clear()  # rebuilt in place so the caller-visible lists hold the wrapped lines
        for line in items:
            wrapped = wrap(calc_d, line.text, font, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]
        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=font)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width
    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in x_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in y_texts]
    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
    title_pad = 0
    if title:
        title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
        title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
    # copy every cell to its padded position
    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + title_pad + (height + margin) * row))
    d = ImageDraw.Draw(result)
    if title:
        x = pad_left + ((width+margin)*cols) / 2
        y = title_pad / 2 - title_text_heights[0] / 2
        draw_texts(d, x, y, title_texts[0], font, fontsize)
    for col in range(cols):
        x = pad_left + (width + margin) * col + width / 2
        y = (pad_top / 2 - hor_text_heights[col] / 2) + title_pad
        draw_texts(d, x, y, x_texts[col], font, fontsize)
    for row in range(rows):
        x = pad_left / 2
        y = (pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2) + title_pad
        draw_texts(d, x, y, y_texts[row], font, fontsize)
    return result
[[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))] ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))] return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin) ================================================ FILE: modules/images_namegen.py ================================================ import re import os import time import unicodedata import uuid import string import hashlib import datetime from pathlib import Path from modules import shared, errors debug= os.environ.get('SD_NAMEGEN_DEBUG', None) is not None debug_log = errors.log.trace if debug else lambda *args, **kwargs: None re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)") re_pattern_arg = re.compile(r"(.*)<([^>]*)>$") re_attention = re.compile(r'[\(*\[*](\w+)(:\d+(\.\d+))?[\)*\]*]|') re_network = re.compile(r'\<\w+:(\w+)(:\d+(\.\d+))?\>|') re_brackets = re.compile(r'[\([{})\]]') re_leading_seq = re.compile(r'^(0*\d+)(?=[-_.\s]|$)') seq = 0 NOTHING = object() class FilenameGenerator: replacements = { 'width': lambda self: self.width, 'height': lambda self: self.height, 'batch_number': lambda self: self.batch_number, 'iter_number': lambda self: self.iter_number, 'num': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, 'generation_number': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime