[
  {
    "path": ".circleci/config.yml",
    "content": "version: 2.1\n\njobs:\n  build:\n    docker:\n     - image: cimg/python:3.9-node\n    resource_class: medium\n\n    # would have been nice, but not for $2,000/month!\n    # machine:\n    #   image: ubuntu-2004-cuda-11.4:202110-01\n    # resource_class: gpu.nvidia.small\n\n    steps:\n      - checkout\n\n      - setup_remote_docker:\n          docker_layer_caching: true\n\n      - run: docker build -t gadicc/diffusers-api .\n\n      # unit tests\n      # - run: docker run gadicc/diffusers-api conda run --no-capture -n xformers pytest --cov=. --cov-report=xml --ignore=diffusers\n      - run: docker run gadicc/diffusers-api pytest --cov=. --cov-report=xml --ignore=diffusers --ignore=Real-ESRGAN\n\n      - run: echo $DOCKER_PASSWORD | docker login --username $DOCKER_USERNAME --password-stdin\n\n      # push for non-semver branches (e.g. dev, feature branches)\n      # - run:\n      #     name: Push to hub on branches not handled by semantic-release\n      #     command: |\n      #       SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\\/\\/.*//' | jq .branches[])\n      # \n      #       if [[ ${SEMVER_BRANCHES[@]} =~ \"$CIRCLE_BRANCH\" ]] ; then\n      #         echo \"Skipping because '\\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'\"\n      #         echo \"Semantic-release will handle the publishing\"\n      #       else\n      #         echo \"docker push gadicc/diffusers-api:$CIRCLE_BRANCH\"\n      #         docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH .\n      #         docker push gadicc/diffusers-api:$CIRCLE_BRANCH\n      #         echo \"Skipping integration tests\"\n      #         circleci-agent step halt\n      #       fi\n\n      # needed for later \"apt install\" steps\n      - run: sudo apt-get update\n\n      ## TODO.  
The below was a great first step, but in future, let's build\n      # the container on the host, run docker remotely on lambda, and\n      # publish the same built image if tests pass.\n\n      # TODO, only run on main channel for releases (with sem-rel too)\n      # integration tests\n      - run: sudo apt install -yqq rsync pv\n      - run: ./run_integration_tests_on_lambda.sh\n\n      - run:\n          name: Push to hub on branches not handled by semantic-release\n          command: |\n            SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\\/\\/.*//' | jq .branches[])\n\n            if [[ ${SEMVER_BRANCHES[@]} =~ \"$CIRCLE_BRANCH\" ]] ; then\n              echo \"Skipping because '\\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'\"\n              echo \"Semantic-release will handle the publishing\"\n            else\n              echo \"docker push gadicc/diffusers-api:$CIRCLE_BRANCH\"\n              docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH .\n              docker push gadicc/diffusers-api:$CIRCLE_BRANCH\n              # echo \"Skipping integration tests\"\n              # circleci-agent step halt\n            fi\n\n      # deploy the image\n      # - run: docker push company/app:$CIRCLE_BRANCH\n      # https://github.com/semantic-release-plus/semantic-release-plus/tree/master/packages/plugins/docker\n      - run:\n          name: release\n          command: |\n            sudo apt-get install yarn\n            yarn install\n            yarn run semantic-release-plus"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile\n{\n\t\"name\": \"Existing Dockerfile\",\n\t\"build\": {\n\t\t// Sets the run context to one level up instead of the .devcontainer folder.\n\t\t\"context\": \"..\",\n\t\t// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.\n\t\t\"dockerfile\": \"../Dockerfile\"\n\t},\n\n\t// Features to add to the dev container. More info: https://containers.dev/features.\n\t\"features\": {\n\t\t\"ghcr.io/devcontainers/features/python:1\": {\n\t\t\t// \"version\": \"3.10\"\n\t\t}\n\t},\n\n\t// Use 'forwardPorts' to make a list of ports inside the container available locally.\n\t\"forwardPorts\": [8000],\n\n\t// Uncomment the next line to run commands after the container is created.\n\t\"postCreateCommand\": \"scripts/devContainerPostCreate.sh\",\n\n\t\"customizations\": {\n\t\t\"vscode\": {\n\t\t\t\"extensions\": [\n\t\t\t\t\"ryanluker.vscode-coverage-gutters\",\n\t\t\t\t\"fsevenm.run-it-on\",\n\t\t\t\t\"ms-python.black-formatter\",\n\t\t\t],\n\t\t\t\"settings\": {\n\t\t\t\t\"python.pythonPath\": \"/opt/conda/bin/python\"\n\t\t\t}\n\t\t}\n\t},\n\n\t// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.\n\t// \"remoteUser\": \"devcontainer\"\n\n\t\"mounts\": [\n\t\t\"source=${localEnv:HOME}/root-cache,target=/root/.cache,type=bind,consistency=cached\"\n\t],\n\n\t\"runArgs\": [\n    \"--gpus\",\n    \"all\",\n\t\t\"--env-file\",\n\t\t\".devcontainer/local.env\"\n\t]\n}\n"
  },
  {
    "path": ".devcontainer/local.example.env",
    "content": "# Useful environment variables:\n\n# AWS or S3-compatible storage credentials and buckets\nAWS_ACCESS_KEY_ID=\nAWS_SECRET_ACCESS_KEY=\nAWS_DEFAULT_REGION=\nAWS_S3_DEFAULT_BUCKET=\n# Only fill this in if your (non-AWS) provider has told you what to put here\nAWS_S3_ENDPOINT_URL=\n\n# To use a proxy, e.g.\n# https://github.com/kiri-art/docker-diffusers-api/blob/dev/CONTRIBUTING.md#local-https-caching-proxy\n# DDA_http_proxy=http://172.17.0.1:3128\n# DDA_https_proxy=http://172.17.0.1:3128\n\n# HuggingFace credentials\nHF_AUTH_TOKEN=\nHF_USERNAME=\n\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\n/lib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\npermutations\ntests/output\nnode_modules\n.devcontainer/local.env\n"
  },
  {
    "path": ".vscode/settings.json",
    "content": "{\n  \"python.testing.pytestArgs\": [\n    \"--cov=.\",\n    \"--cov-report=xml\",\n    \"--ignore=test.py\",\n    \"--ignore=tests/integration\",\n    \"--ignore=diffusers\",\n    // \"unit_tests.py\"\n    // \".\"\n  ],\n  \"python.testing.unittestEnabled\": false,\n  \"python.testing.pytestEnabled\": true,\n  // \"python.defaultInterpreterPath\": \"/opt/conda/envs/xformers/bin/python\",\n  \"python.defaultInterpreterPath\": \"/opt/conda/bin/python\",\n  \"runItOn\": {\n    \"commands\": [\n        {\n            \"match\": \"\\\\.py$\",\n            \"isAsync\": true,\n            \"isShellCommand\": false,\n            \"cmd\": \"testing.runAll\"\n        },\n    ],\n  },\n  \"[python]\": {\n    \"editor.defaultFormatter\": \"ms-python.black-formatter\"\n  },\n  \"python.formatting.provider\": \"none\"\n}\n"
  },
  {
    "path": ".vscode/tasks.json",
    "content": "{\n  // See https://go.microsoft.com/fwlink/?LinkId=733558\n  // for the documentation about the tasks.json format\n  \"version\": \"2.0.0\",\n  \"tasks\": [\n      {\n          \"label\": \"Watching Server\",\n          \"type\": \"shell\",\n          \"command\": \"scripts/devContainerServer.sh\"\n      }\n  ]\n}"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# [1.7.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.6.0...v1.7.0) (2023-09-04)\n\n\n### Bug Fixes\n\n* **addons:** async TI download status, LoRA improvements ([de8cfdc](https://github.com/kiri-art/docker-diffusers-api/commit/de8cfdc63d7ae46bed90862fe3bffe65534d3e55))\n* **circleci:** pytest --ignore=Real-ESRGAN ([d7038b5](https://github.com/kiri-art/docker-diffusers-api/commit/d7038b5aa54c8b3dab2149ea773e007b9c0202ce))\n* **circleci:** remove conda from pytest call ([2f29af2](https://github.com/kiri-art/docker-diffusers-api/commit/2f29af2c012ef38ed2e2bc0ec116b59b8c429e57))\n* **diffusers:** bump to aae2726 (jul30) post v0.19.2 + fixes ([6c0a10a](https://github.com/kiri-art/docker-diffusers-api/commit/6c0a10a743abb7cd12cce9bf1cc6a598c6804e92))\n* **Dockerfile:** -yqq for apt-get, apt-utils, extra deps ([bf470da](https://github.com/kiri-art/docker-diffusers-api/commit/bf470dabb9b3c6d7f16d11126ffef0f4ee4806f5))\n* **Dockerfile:** TZ tzdata fix ([9c5d911](https://github.com/kiri-art/docker-diffusers-api/commit/9c5d911aafedc1a2dab94a5c1c1c25aa4bc0ce7a))\n* **misc:** fix failing tests, pipeline init in rare circumstances ([9338648](https://github.com/kiri-art/docker-diffusers-api/commit/933864893a35dfb9fa093b988a5b159af4e0a9ca))\n* **prime/update:** commit these useful utility scripts ([7b167c0](https://github.com/kiri-art/docker-diffusers-api/commit/7b167c0508e7a476d8c6719e056d6bdfa255e2d8))\n* **upsample:** return $meta for kiri ([b9dd6b7](https://github.com/kiri-art/docker-diffusers-api/commit/b9dd6b780005ad17090220fba99f0329b98f9c09))\n* **x_attn_kwargs:** only pass to pipeline if set ([3f1f980](https://github.com/kiri-art/docker-diffusers-api/commit/3f1f980930edb9bad28c6c026d31ca084887b442))\n\n\n### Features\n\n* **checkpoints:** use correct pipeline for \"inpaint\" in path ([16dd383](https://github.com/kiri-art/docker-diffusers-api/commit/16dd38327d291de29da012026a2ffcede0681526))\n* **loras:** ability to specify #?scale=0.1 -> 
cross_attn_kwargs ([747fc0d](https://github.com/kiri-art/docker-diffusers-api/commit/747fc0ddec1db91617fb01f4d7ef9b8291de221d))\n* **pytorch2:** bump deps, drop conda/xformers ([a3d8078](https://github.com/kiri-art/docker-diffusers-api/commit/a3d807896e2b0d831580b78be556fcc69be08353))\n* **sdxl,compel:** Support. AutoPipeline default, safety_check fix ([993be12](https://github.com/kiri-art/docker-diffusers-api/commit/993be124c2e5b0f04b1cf25ca285e3a6573ce19a))\n* **sdxl:** fix sd_xl, loras; ability to init load specific pipeline ([7e3af77](https://github.com/kiri-art/docker-diffusers-api/commit/7e3af77167b58481d3c974ae33c3991ef976fc28))\n* **textualInversion:** very early support ([2babd53](https://github.com/kiri-art/docker-diffusers-api/commit/2babd539a6fcb396bb1f323fe9c50cdccb91cf96))\n* **upsample:** initial RealESRGAN support for runtime downloads ([8929508](https://github.com/kiri-art/docker-diffusers-api/commit/8929508adea8cd0e50ccf79aaea2a13354f37fa8))\n\n# [1.6.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.5.0...v1.6.0) (2023-07-12)\n\n\n### Bug Fixes\n\n* **BaseStorage:** mv misplaced .query from BaseArchive to BaseStorage ([0c7a757](https://github.com/kiri-art/docker-diffusers-api/commit/0c7a757634cb62bacb3efda7f9a6e4b85bb3cb4e))\n* **conversion:** recognize \"safetensor\" anywhere in filename ([1ceab7d](https://github.com/kiri-art/docker-diffusers-api/commit/1ceab7dfb1d0d507b3b61f777453d81caf5190c2))\n* **deps:** bump diffusers to b9feed8, lock bitsandbytes==0.39.1 ([be1c322](https://github.com/kiri-art/docker-diffusers-api/commit/be1c32218cd0e312077de2b7a10b41f2f5be07e0))\n* **deps:** diffusers to 0.17.0 + latest commits, other packages ([a6e9db0](https://github.com/kiri-art/docker-diffusers-api/commit/a6e9db09382d972da3c6c08786ff92986e7585b7))\n* **pipelines:** pass revision/precision for community pipelines too ([20311cf](https://github.com/kiri-art/docker-diffusers-api/commit/20311cf51babf16609af1495585a4e9fca1f05e4))\n* 
**safety_checker:** drop DummySafetyChecker and just use None ([e4fbf22](https://github.com/kiri-art/docker-diffusers-api/commit/e4fbf225e0f09c8591f2537e3061977fad6386ed))\n\n\n### Features\n\n* **checkpoints:** support #fname query in HTTPStorage ([0cb839d](https://github.com/kiri-art/docker-diffusers-api/commit/0cb839db75f86c07d568b4a379bedba971340eb0))\n* **dreambooth:** update / merge in all upstream changes to date ([a40129a](https://github.com/kiri-art/docker-diffusers-api/commit/a40129a2b2f47282cc463d1249985d4b07ec16c9))\n* **loras:** use load_lora_weights (works with A1111 files too) ([7a64846](https://github.com/kiri-art/docker-diffusers-api/commit/7a6484642a11fc3f3de780d4627de2dd48607d89))\n* **storage:** allow #a=1&b=2 params; HTTP can use #fname=XXX ([4fe13ef](https://github.com/kiri-art/docker-diffusers-api/commit/4fe13ef7fbd4948e5f665e3d38a57430def561b8))\n\n# [1.5.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.4.0...v1.5.0) (2023-05-24)\n\n\n### Bug Fixes\n\n* **app:** async fixes for download, train_dreambooth ([0dcbd16](https://github.com/kiri-art/docker-diffusers-api/commit/0dcbd16c1a85a9f3fb867a28d66b00f0eccaba80))\n* **app:** diffusers callback cannot be async; use asyncio.run() ([7854649](https://github.com/kiri-art/docker-diffusers-api/commit/7854649011d370497690618fe3ea0e8ce2c79bc6))\n* **app:** up sanic RESPONSE_TIMEOUT from 1m to 1hr ([8e2003a](https://github.com/kiri-art/docker-diffusers-api/commit/8e2003afad8af93d4e1442138d6b7673e32af971))\n* **attn_procs:** apply workaround only for storage not hf repos ([b98710f](https://github.com/kiri-art/docker-diffusers-api/commit/b98710f144265df3d77a90bfb39d2dd30fbd8c96))\n* **attn_procs:** load non-safetensors attn_procs ourself ([072e7a3](https://github.com/kiri-art/docker-diffusers-api/commit/072e7a38f13d66b3e069427c318e16dcd5b6324d)), closes 
[/github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119](https://github.com//github.com/huggingface/diffusers/pull/2448/issues/issuecomment-1453938119)\n* **deps:** pin websockets<11.0 for sanic ([33ae2f4](https://github.com/kiri-art/docker-diffusers-api/commit/33ae2f4c905c5e92aa9ff6cc2f61a3adb81b1b59))\n* **inference:** return $error NO_MODEL_ID vs later crash on None ([46ea977](https://github.com/kiri-art/docker-diffusers-api/commit/46ea977cea6e469059931d722df5a38a3f931d77))\n* **storage:** actually, always set self.status (default None) ([c309ca9](https://github.com/kiri-art/docker-diffusers-api/commit/c309ca92fd1038f89dae186e35cc732e5822c8c2))\n* **storage:** don't set self.status to None ([9b88b80](https://github.com/kiri-art/docker-diffusers-api/commit/9b88b8089c4063e63aab547ce945ebb1a94f2fd7))\n* **storage:** extract with dir= must not mutate dir (download, logs) ([b1f8f87](https://github.com/kiri-art/docker-diffusers-api/commit/b1f8f87756f61ae0aa61c3785911ab043f911d98))\n* **tests:** pin urlllib3 to < 2, avoids break in docker package ([ccf8231](https://github.com/kiri-art/docker-diffusers-api/commit/ccf823139ac0f379e2f27d8dd5921f5343f20f8a))\n\n\n### Features\n\n* **app:** run pipeline via asyncio.to_thread ([e87f7e7](https://github.com/kiri-art/docker-diffusers-api/commit/e87f7e772fa1f5f22957600572be60b150999095))\n* **attn_procs:** from_safetensors override, save .savetensors fname ([5fb6487](https://github.com/kiri-art/docker-diffusers-api/commit/5fb6487579d8b809c52f9451c68bcfcafecca0f0))\n* **cors:** add sanic-ext and set default cors-origin to \"*\" ([eb2a385](https://github.com/kiri-art/docker-diffusers-api/commit/eb2a385684a309557b637d7c03f2e8cda00137b0))\n* **diffusers:** bump to 0.15.0 + 2 weeks with lpw fix (9965cb5) ([77e9078](https://github.com/kiri-art/docker-diffusers-api/commit/77e907892b5b6b9b27aa75f5ec5732a81ba784d6))\n* **diffusers:** bump to latest diffusers, 0.14 + patches (see note) 
([48a99a5](https://github.com/kiri-art/docker-diffusers-api/commit/48a99a532503bf9f8932f64ddf20d7b81aab765b))\n* **download:** async, status; download.py: use download_and_extract ([bb7434a](https://github.com/kiri-art/docker-diffusers-api/commit/bb7434a4e39d02dce5ecbf602fe6e41511481c12))\n* **HTTPStorage:** store filename from content-disposition ([2066c44](https://github.com/kiri-art/docker-diffusers-api/commit/2066c446ba058209d1f594a46a8af0188e6e82fa))\n* **loadModel:** send loadModel status ([db75740](https://github.com/kiri-art/docker-diffusers-api/commit/db75740177688e25bba4066d099a2c034dd3eb93))\n* **status:** initial status work ([d1cd39e](https://github.com/kiri-art/docker-diffusers-api/commit/d1cd39ea93e4c967be91ed59b8b05a6ce9f117da))\n* **storage:** support misc tar compression; progress ([a8c8337](https://github.com/kiri-art/docker-diffusers-api/commit/a8c8337da4b750f92f9712397293da20974aa385))\n* **stream_events:** stream send()'s to client too ([08daf4f](https://github.com/kiri-art/docker-diffusers-api/commit/08daf4fdca1f3ad23965e9bf14a3b66fc57279fd))\n\n# [1.4.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.3.0...v1.4.0) (2023-02-28)\n\n\n### Bug Fixes\n\n* **checkpoints:** new conversion pipeline + convert w/o MODEL_URL ([cd7f54d](https://github.com/kiri-art/docker-diffusers-api/commit/cd7f54db370462f6c3e7ecb37df791388a9ccd34))\n* **diffusers:** bump to latest commit (includes v0.13.1) ([400e3d7](https://github.com/kiri-art/docker-diffusers-api/commit/400e3d7b0897e966ba3c1cc04194aedde8746edf))\n* **diffusers:** bump to recent commit, includes misc LoRA fixes ([7249c30](https://github.com/kiri-art/docker-diffusers-api/commit/7249c307a9c2892a061398e75cd70965329c3ac6))\n* **loadModel:** pass revision arg too ([cd5f995](https://github.com/kiri-art/docker-diffusers-api/commit/cd5f995dad9123aa4ea066ad4b9d369ef01df06b))\n\n\n### Features\n\n* **attn_procs:** initial URL work (see notes) 
([6348836](https://github.com/kiri-art/docker-diffusers-api/commit/6348836622da4a17fa0e423ca9b92ebb489b4793))\n* **callback:** if modelInput.callback_steps, send() current step ([2279de1](https://github.com/kiri-art/docker-diffusers-api/commit/2279de103d70614fbdee620024941dd1db81c436))\n* **gpu:** auto-detect GPU (CUDA/MPS/cpu), remove hard-coded ([#20](https://github.com/kiri-art/docker-diffusers-api/issues/20)) ([682a342](https://github.com/kiri-art/docker-diffusers-api/commit/682a34221f5b586fd0d8e9c0789201cb238cf225))\n* **lora:** callInput `attn_procs` to load LoRA's for inference ([cb54291](https://github.com/kiri-art/docker-diffusers-api/commit/cb542910fd234af0a02a862934bf5c090384500d))\n* **send:** set / override SEND_URL, SIGN_KEY via callInputs ([74b4c53](https://github.com/kiri-art/docker-diffusers-api/commit/74b4c53bd49691df087364959123cfd48e04ac59))\n\n# [1.3.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.2...v1.3.0) (2023-01-26)\n\n\n### Bug Fixes\n\n* **diffusers:** bump to v0.12.0 ([635d9d9](https://github.com/kiri-art/docker-diffusers-api/commit/635d9d97a010c49ef7875fcb4b43b668848ced0b))\n* **diffusers:** update to latest commit ([87632aa](https://github.com/kiri-art/docker-diffusers-api/commit/87632aa2c32faddfeb049fe969884b568066edd3))\n* **dreambooth:** bump diffusers, fixes fp16 mixed precision training ([0f5d5ff](https://github.com/kiri-art/docker-diffusers-api/commit/0f5d5ff2bf5b73260b9d60521389f0938f205219))\n* **dreambooth:** merge commits to v0.12.0 (NB: mixed-precision issue) ([88f04f8](https://github.com/kiri-art/docker-diffusers-api/commit/88f04f870814aa9baf2a7c09513dcc796070b814))\n* **pipelines:** fix clearPipelines() backport from cloud-cache ([9577f93](https://github.com/kiri-art/docker-diffusers-api/commit/9577f9344f0060edc185e32eadeb57e83551aa7f))\n* **requirements:** bump transformers,accelerate,safetensors & others 
([aebcf65](https://github.com/kiri-art/docker-diffusers-api/commit/aebcf6562808a817e6ee29e88f178f22f54c861b))\n* **re:** use raw strings r\"\" for regexps ([41310c2](https://github.com/kiri-art/docker-diffusers-api/commit/41310c26bbc19069db492781313b162f0fc4d7d9))\n* **tests/lambda:** export HF_AUTH_TOKEN ([9f11e7b](https://github.com/kiri-art/docker-diffusers-api/commit/9f11e7b2f0d2a377a44b22d446274677bd025813))\n* **test:** shallow copy to avoid mutating base test inputs ([8c41167](https://github.com/kiri-art/docker-diffusers-api/commit/8c41167461308b14066be1472fd8957dc6cdd658))\n\n\n### Features\n\n* **downloads:** RUNTIME_DOWNLOAD from HF when no MODEL_URL given ([73784a1](https://github.com/kiri-art/docker-diffusers-api/commit/73784a1844ef2b14c628eb399bec0e52661df35c))\n\n## [1.2.2](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.1...v1.2.2) (2023-01-09)\n\n\n### Bug Fixes\n\n* **dreambooth:** runtime_dls path fix; integration tests ([ce3827f](https://github.com/kiri-art/docker-diffusers-api/commit/ce3827f6aabd5158c39c99ffae0358d832de2e39))\n* **loadModel:** revision = None if revision == \"\" else revision ([1773631](https://github.com/kiri-art/docker-diffusers-api/commit/1773631e292e28fae20b0a6c93406378aed85d47))\n\n## [1.2.1](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.0...v1.2.1) (2023-01-05)\n\n\n### Bug Fixes\n\n* **build-download:** support regular HF download not just cloud cache ([52edf6b](https://github.com/kiri-art/docker-diffusers-api/commit/52edf6b8e52cba4a03c8ea0f72b8fd1e69fa87ad))\n\n# [1.2.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.1.0...v1.2.0) (2023-01-04)\n\n\n### Features\n\n* **build:** separate MODEL_REVISION, MODEL_PRECISION, HF_MODEL_ID ([fa9dd16](https://github.com/kiri-art/docker-diffusers-api/commit/fa9dd16b7369d37f3997ef46581df471bca8e7c1))\n\n# [1.1.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.2...v1.1.0) (2023-01-04)\n\n\n### Features\n\n* **downloads:** allow 
HF_MODEL_ID call-arg (defauls to MODEL_ID) ([adaa7f6](https://github.com/kiri-art/docker-diffusers-api/commit/adaa7f67aba49058b2e52117e6eb0fed6417b773))\n* **downloads:** allow separate MODEL_REVISION and MODEL_PRECISION ([6edc821](https://github.com/kiri-art/docker-diffusers-api/commit/6edc821da1593f34e4502352dba8f2f4cd808e95))\n\n## [1.0.2](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.1...v1.0.2) (2023-01-01)\n\n\n### Bug Fixes\n\n* **diffusers:** bump to 2022-12-30 commit 62608a9 ([2f29165](https://github.com/kiri-art/docker-diffusers-api/commit/2f291655967a253b81da9f44c99d4ac68e1c8353))\n\n## [1.0.1](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.0...v1.0.1) (2022-12-31)\n\n\n### Bug Fixes\n\n* **ci:** different token, https auth ([ecd0b5d](https://github.com/kiri-art/docker-diffusers-api/commit/ecd0b5d8efe734693ff9647cfc2d0bc0b8f90e42))\n\n# 1.0.0 (2022-12-31)\n\n\n### Bug Fixes\n\n* **app:** clearPipelines() before loadModel() to free RAM ([ec45acf](https://github.com/kiri-art/docker-diffusers-api/commit/ec45acf7db7796682597d1d1c440d3742df84425))\n* **app:** init: don't process MODEL_ID if not RUNTIME_DOWNLOADS ([683677f](https://github.com/kiri-art/docker-diffusers-api/commit/683677f0bdbd49c11cb0310c7c365047b536a4f7))\n* **dockerfile:** bump diffusers to eb1abee693104dd45376dbddd614320f2a0beb24 ([1769330](https://github.com/kiri-art/docker-diffusers-api/commit/1769330d4ec1f5932591383daf078be0953accdc))\n* **downloads:** model_url, model_id should be optional ([9a19e7e](https://github.com/kiri-art/docker-diffusers-api/commit/9a19e7e1e742c46471f9a7e6fcebacea5f887d35))\n* **dreambooth:** don't crash on cleanup when no class_data_dir created ([36e64b1](https://github.com/kiri-art/docker-diffusers-api/commit/36e64b101bb12c7e09445f5958acaab1ab59a301))\n* **dreambooth:** enable mixed_precision training, default to fp16 ([0430d23](https://github.com/kiri-art/docker-diffusers-api/commit/0430d2380b5c6e5e43f2c8657017ba701bfaec41))\n* 
**gitScheduler:** fix deprecation warning s/from_config/from_pretrained/ ([92b2b43](https://github.com/kiri-art/docker-diffusers-api/commit/92b2b433bd9dfb4e1af1473cfa430e55bc83b170))\n* **pipelines:** community pipelines, set torch_dtype too ([0cc1b63](https://github.com/kiri-art/docker-diffusers-api/commit/0cc1b63f72f98ad9267cdc71707bb4b533ad303d))\n* **pipelines:** fix clearPipelines(), load model w/ correct precision ([3085412](https://github.com/kiri-art/docker-diffusers-api/commit/308541243c78cf528ebcd4c68900f5cdd52e6f8f))\n* **requirements:** bumps transformers from 4.22.2 to 4.25.1 ([b13b58c](https://github.com/kiri-art/docker-diffusers-api/commit/b13b58c89fcd30e90ebb58c193c803450db43ebd))\n* **s3:** incorrect value for tqdm causing crash ([9527ece](https://github.com/kiri-art/docker-diffusers-api/commit/9527ece90e4b5b4366f1c418d837dd659764203c))\n* **send:** container_id detection, use /containers/ to grep ([5c0606a](https://github.com/kiri-art/docker-diffusers-api/commit/5c0606a0fdfd9b1a410b6f96eff009da6b768dbe))\n* **tests:** default to DPMSolverMultistepScheduler and 20 steps ([a9c7bb0](https://github.com/kiri-art/docker-diffusers-api/commit/a9c7bb091821640a84d37d3090d365b7a54f2615))\n\n\n### Features\n\n* ability for custom config.yaml in CHECKPOINT_CONFIG_URL ([d2b507c](https://github.com/kiri-art/docker-diffusers-api/commit/d2b507ca225a033dda35897999e489541faecb8c))\n* add PyPatchMatch for outpainting support ([3675bd3](https://github.com/kiri-art/docker-diffusers-api/commit/3675bd31a12d7b1f9627e34f59b661ea7261c272))\n* **app:** don't track downloads in mem, check on disk ([51729e2](https://github.com/kiri-art/docker-diffusers-api/commit/51729e21440e4f0721b73ea497ddd2136306f11d))\n* **app:** runtime downloads with MODEL_URL ([7abc4ac](https://github.com/kiri-art/docker-diffusers-api/commit/7abc4aced15f4aec441d4c220f39e046d2e35179))\n* **app:** runtime downloads, re-use loaded model if requested again 
([b84e822](https://github.com/kiri-art/docker-diffusers-api/commit/b84e822cacdb249693a301eb62a600ac9e0ee8f9))\n* **callInputs:** `MODEL_ID`, `PIPELINE`, `SCHEDULER` now optional ([ef420a1](https://github.com/kiri-art/docker-diffusers-api/commit/ef420a1022b3d80950e7df79f1aff006e775c313))\n* **cloud_cache:** normalize model_id and include precision ([ad1b2ef](https://github.com/kiri-art/docker-diffusers-api/commit/ad1b2efc60216c7a8854139ae816d78f6c4a9a19))\n* **diffusers:** bump to v0.10.12 and one commit after (6b68afd) ([ec9117b](https://github.com/kiri-art/docker-diffusers-api/commit/ec9117b747985b7b3d80a4211c4e7bf6253a24a1))\n* **diffusers:** bump to v0.9.0 ([0504d97](https://github.com/kiri-art/docker-diffusers-api/commit/0504d97e38eb85924ef7453c3c8690428f54870d))\n* **docker:** diffusers-api-base image, build, run.sh ([1cbfc4f](https://github.com/kiri-art/docker-diffusers-api/commit/1cbfc4f41b46ea8d38600ac6902cf5f095357344))\n* **dockerfile:** FROM_IMAGE build-arg to pick base image ([a0c37a6](https://github.com/kiri-art/docker-diffusers-api/commit/a0c37a6a87b300771f6ecf168b8bb1516caa5ab9))\n* **Dockerfile:** make SDv2 the default (+ some formatting cleanup) ([c1e73ef](https://github.com/kiri-art/docker-diffusers-api/commit/c1e73efcdb6e5c95d36c83f9d1398182a1b7e77e))\n* **dockerfile:** runtime downloads ([b40ae86](https://github.com/kiri-art/docker-diffusers-api/commit/b40ae868ce59ddb0232bcdb27ebb0a2c91068f51))\n* **Dockerfile:** SAFETENSORS_FAST_GPU ([62209be](https://github.com/kiri-art/docker-diffusers-api/commit/62209be9963f9699ba32ea7520a361545b55034e))\n* **download:** default_path as normalized_model_id.tar.zst ([5ad0d88](https://github.com/kiri-art/docker-diffusers-api/commit/5ad0d88b0b9b5a5a07596457c3bc83b7b32b25f5))\n* **download:** delete .zst file after uncompress ([ab25280](https://github.com/kiri-art/docker-diffusers-api/commit/ab25280125bc1ccc38a0a2588fc09e33a576f6b0))\n* **download:** record download timings 
([7457e50](https://github.com/kiri-art/docker-diffusers-api/commit/7457e505c826c44d9f45a05fe486e819d442b4ca))\n* **downloads:** runtime checkpoint conversion ([2414cd9](https://github.com/kiri-art/docker-diffusers-api/commit/2414cd9e3ac232273a1f2441134c65c25d0f7b49))\n* **dreambooth:** save in safetensors format, tar up with -v ([5c3e86a](https://github.com/kiri-art/docker-diffusers-api/commit/5c3e86a8f99331c41c34b36c932b70e11f7b80b0))\n* **errors:** try...catch everything, return as JSON ([901679c](https://github.com/kiri-art/docker-diffusers-api/commit/901679c7829796dc585af25f658cd6ab9115c7e7))\n* **getScheduler:** make DPMSolverMultistepScheduler the default ([085d06f](https://github.com/kiri-art/docker-diffusers-api/commit/085d06f6b993a24b16521a1c3ee77d92289e04ed))\n* **k-diffusion:** add pip package for use in k-diffusion shedulers ([3e901ad](https://github.com/kiri-art/docker-diffusers-api/commit/3e901adc64f750f5501b5dd19d87d0a5e294de22))\n* **models:** store in ~/.cache/diffusers-api (volume support) ([8032ec1](https://github.com/kiri-art/docker-diffusers-api/commit/8032ec11b8f6590015110c9b89437f5619f2374c))\n* **pipelines:** allow calling of ALL PIPELINES (official+community) ([1ccbaad](https://github.com/kiri-art/docker-diffusers-api/commit/1ccbaad1f405b8e5d16ca1a9880cc1d279f6d3f9))\n* **pipelines:** initial community pipeline support ([7af45cf](https://github.com/kiri-art/docker-diffusers-api/commit/7af45cfdc4cbcc95c905834628775d0e8858509e))\n* **s3:** s3client(), file_exists() methods ([0308af9](https://github.com/kiri-art/docker-diffusers-api/commit/0308af910d07be6d912104663263663b086def9c))\n* **s3:** upload/download progress indicators ([76dd303](https://github.com/kiri-art/docker-diffusers-api/commit/76dd303a58a57b90ecc2c0038547b23b906ecca5))\n* **send:** prefer env var CONTAINER_ID if set to full docker uuid ([eec5112](https://github.com/kiri-art/docker-diffusers-api/commit/eec511252035b8205f5365f45abb5777c164cb57))\n* **send:** SEND_URL and 
SIGN_KEY now settable with build-vars ([01cf354](https://github.com/kiri-art/docker-diffusers-api/commit/01cf35461c5855a75651a30e3aeccb4ad1e9c8ac))\n* **test:** allow TEST_URL to override https://localhost:8000/ ([9b46387](https://github.com/kiri-art/docker-diffusers-api/commit/9b463872257c0a3ffae553765aed62a2df6af717))\n* **tests:** allow override BANANA_API_URL ([aca6aca](https://github.com/kiri-art/docker-diffusers-api/commit/aca6aca6e7ed46d0bf711548cea82a588fdd7d2a))\n\n# CHANGELOG\n\n* **NEXT MAIN**\n\n  * Callinputs `MODEL_ID`, `PIPELINE` and `SCHEDULER` are **now optional**.\n    If not specified, the default will be used, and returned in a `$meta`\n    key in the result.\n\n  * Tests: 1) Don't specify above defaults where possible, 2) Log exact\n    inputs sent to container, 3) Log the full result sent back,\n    substituting base64 image strings with their info, 4) format stack\n    traces on caught errors from container.\n\n* **NEXT MAIN (and already posted to forum)**\n\n  * **Latest diffusers, SDv2.1**.  All the latest goodness, and upgraded some\n    dependencies too.  Models are:\n\n    * `stabilityai/stable-diffusion-2-1-base` (512x512)\n    * `stabilityai/stable-diffusion-2-1` (768x768)\n\n  * **ALL THE PIPELINES**.  We no longer load a list of hard-coded pipelines\n    in `init()`.  Instead, we init and cache each on first use (for faster\n    first calls on cold boots), and, *all* pipelines, both official diffusers\n    and community pipelines, are available.\n    [Full details](https://banana-forums.dev/t/all-your-pipelines-are-belong-to-us/83)\n\n  * Dreambooth: Enable `mixed_precision` training, default to fp16.\n\n  * [Experimental] **[Runtime downloads](https://banana-forums.dev/t/runtime-downloads-dont-download-during-build/81/3)** (Dreambooth\n  only for now, more on the way)\n\n  * **S3**: Add upload/download progress indicators.\n\n  * Stable Diffusion has standardized **`image` instead of `init_image`** for\n    all pipelines.  
Using `init_image` now shows a deprecation warning and\n    will be removed in future.\n\n  * **Changed `sd-base` to `diffusers-api`** as the default tag / name used\n    in the README examples and optional [./build][build script].\n\n  * **Much better error handling**.  We now `try...except` both the pipeline\n    run and entire `inference()` call, which will save you a trip to banana's\n    logs which don't always even show these errors and sometimes just leave\n    you with an unexplained stuck instance.  These kinds of errors are almost\n    always a result of problematic callInputs and modelInputs used for the\n    pipeline call, so finding them will be a lot easier now.\n\n* **2022-11-29**\n\n  * **Diffusers v0.9.0, Stable Diffusion v2.0**.  Models:\n      * `\"stabilityai/stable-diffusion-2\"` - trained on 768x768\n      * `\"stabilityai/stable-diffusion-2-base\"` - trained on 512x512\n      * `\"stabilityai/stable-diffusion-2-inpainting\"` - untested\n      * `\"\"stabilityai/stable-diffusion-x4-upscaler\"` - untested\n\n    > https://github.com/huggingface/diffusers/releases\n\n    **NB**: SDv2 does not include a safety_checker.  The model itself is\n    \"safe\" (it's much harder to create NSFW content).  Trying to \"turn off\"\n    the (non-existent) safety checker will throw an error, we'll handle this\n    more gracefully in a future release.  This also means you can safely\n    ignore this warning on loading:\n\n    ```\n    You have disabled the safety checker for\n    <class diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>\n    by passing safety_checker=None. Ensure that...\n    ```\n\n  * **DPMSolverMultistepScheduler**.  Docker-diffusers-api is simply a wrapper\n    around diffusers.  We support all the included schedulers out of the box,\n    as long as they can init themselves with default arguments.  So, the above\n    scheduler was already working, but we didn't mention it before.  
I'll just\n    quote diffusers:\n\n    > DPMSolverMultistepScheduler is the firecracker diffusers implementation\n    of DPM-Solver++, a state-of-the-art scheduler that was contributed by one\n    of the authors of the paper. This scheduler is able to achieve great\n    quality in as few as 20 steps. It's a drop-in replacement for the default\n    Stable Diffusion scheduler, so you can use it to essentially half\n    generation times.\n\n  * **Storage Class / S3 support**.  We now have a generic storage class, which\n    allows for special URLs anywhere anywhere you can usually specify a URL,\n    e.g. `CHECKPOINT_URL`, `dest_url` (after dreambooth training), and the new\n    `MODEL_URL` (see below).  URLs like \"s3:///bucket/filename\" will work how\n    you expect, but definitely read [docs/storage.md](./docs/storage.md)\n    to understand the format better.  Note in particular the triple forwardslash\n    (\"///\") in the beginning to use the default S3 endpoint.\n\n  * **Dreambooth training**, working but still in development.  See\n    [this forum post](https://banana-forums.dev/t/dreambooth-training-first-look/36)\n    for more info.\n\n  * **`PRECISION`** build var, defaults to `\"fp16\"`, set to `\"\"` to use the model\n    defaults (generally fp32).\n\n  * **`CHECKPOINT_URL` conversion**:\n    * Crash / stop build if conversion fails (rather than unclear errors later on)\n    * Force `cpu` loading even for models that would otherwise default to GPU.\n      This fixes certain models that previously crashed in build stage (where GPU\n      is not available).\n    * `--extract-ema` on conversion since these are the more important weights for\n      inference.\n    * `CHECKPOINT_CONFIG_URL` now let's to specify a specific config file for \n      conversion, to use instead of SD's default `v1-inference.yaml`.\n\n  * **`MODEL_URL`**.  
If your model is already in diffusers format, but you don't\n    host it on HuggingFace, you can now have it downloaded at build time.  At\n    this stage, it should be a `.tar.zst` file.  This is an *alternative* to\n    `CHECKPOINT_URL` which downloads a `.ckpt` file and converts to diffusers.\n\n  * **`test.py`**:\n    * New `--banana` arg to run the test on banana.  Set environment variables\n      `BANANA_API_KEY` and `BANANA_MODEL_KEY` first.\n    * You can now add to and override a test's default json payload with:\n      * `--model-arg prompt=\"hello\"`\n      * `--call-arg MODEL_ID=\"my-model\"`\n    * Support for extra timing data (e.g. dreambooth sends `train`\n      and `upload` timings).\n    * Quit after inference errors, don't keep looping.\n\n  * **Dev: better caching solution**.  No more unruly `root-cache` directory.  See\n    [CONTRIBUTING.md](./CONTRIBUTING.md) for more info.\n\n* **2022-11-08**\n\n  * **Much faster `init()` times!**  For `runwayml/stable-diffusion-v1-5`:\n\n    * Previously: 4.0s, now: 2.4s (40% speed gain)\n\n  * **Much faster `inference()` times!** Particularly from the 2nd inference onwards.\n    Here's a brief comparison of *inference* average times (for 512x512 x50 steps):\n\n    * [Cold] Previously: 3.8s, now: 3.3s (13% speed gain)\n    * [Warm] Previously: 3.2s, now: 2.1s (34% speed gain)\n\n  * **Improved `test.py`**, see [Testing](./README.md#testing)\n\n* **2022-11-05**\n\n  * Upgrade to **Diffusers v0.7.0**.  
There is a lot of fun stuff in this release,\n    but notably for docker-diffusers-api TODAY (more fun stuff coming next week!),\n    we have **much faster init times** (via\n    [`fast_load`](https://github.com/huggingface/diffusers/commit/7482178162b779506a54538f2cf2565c8b88c597)\n    ) and the greatly anticipated support for the Euler schedulers (\n    [a1ea8c0](https://github.com/huggingface/diffusers/commit/a1ea8c01c31a44bf48f6a3b85ccabeb45ef6418f)\n    ).\n\n  * We now use the **full scheduler name** for `callInputs.SCHEDULER`.  `\"LMS\"`,\n    `\"DDIM\"`, `\"PNDM\"` all still work fine for now but give a deprecation warning\n    and will stop working in a future update.  The full list of supported schedulers\n    is: `LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`,\n    `EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`.  These cover the\n    most commonly used / requested schedulers, but we already have code in place to\n    support every scheduler provided by diffusers, which will work in a later\n    diffusers release when they have better defaults.\n\n* **2022-10-24**\n\n  * **Fixed img2img and inpainting pipelines**.  To my great shame, in my rush to get\n    the new models out before the weekend, I inadvertently broke the above two models.\n    Please accept my sincere apology for any confusion this may have caused and\n    especially any of your wasted time in debugging this 🙇\n\n  * **Event logs now shown without `SEND_URL`**.  We optionally log useful info at the\n    start and end of `init()` and `inference()`.  Previously this was only logged if\n    `SEND_URL` was set, to send to an external REST API for logging.  But now, even if\n    we don't send it anywhere, we'll still log this useful info.  
It now also logs\n    the `diffusers` version too.\n\n* **2022-10-21**\n\n  * **Stable Diffusion 1.5 released!!!**\n\n    Accept the license at:\n    [\"runwayml/stable-diffusion-v1-5\"](https://huggingface.co/runwayml/stable-diffusion-v1-5)\n\n    It's the new default model.\n\n  * **Official Stable Diffusion inpainting model**\n\n    Accept the license at:\n    [\"runwayml/stable-diffusion-inpainting\"](https://huggingface.co/runwayml/stable-diffusion-inpainting),\n\n    A few big caveats!\n\n    1) Different model - so back to a separate container for inpainting, also because:\n    2) New pipeline that can't share model struct with other pipelines\n       (see [diffusers#920](https://github.com/huggingface/diffusers/issues/920)).\n    3) Old pipeline is now called `StableDiffusionInpaintPipelineLegacy` (for sd-1.4)\n    4) `model_input` takes `image` now, and not `init_image` like the legacy model.\n    5) There is no `strength` parameter in the new model\n       (see [diffusers#920](https://github.com/huggingface/diffusers/issues/920)).\n\n  * Upgrade to **Diffusers v0.7.0.dev0**\n\n  * **Flash attention** now disabled by default.  1) Because it's built on\n    an older version of diffusers, but also because 2) I didn't succeed in\n    getting much improvement out of it.  Maybe someone else will have better\n    luck.  I think you need big batch sizes to really see the benefit, which\n    doesn't suit my use case.  But please anyone who figures anything out,\n    let us know.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# CONTRIBUTING\n\n*Tips for development*\n\n1. [General Hints](#general)\n1. [Development / Editor Setup](#editors)\n    1. [Visual Studio Code (vscode)](#vscode)\n1. [Testing](#testing)\n1. [Using Buildkit](#buildkit)\n1. [Local HTTP(S) Caching Proxy](#caching)\n1. [Local S3 Server](#local-s3-server)\n1. [Stop on Suspend](#stop-on-suspend)\n\n<a name=\"general\"></a>\n## General\n\n1. Run docker with `-it` to make it easier to stop container with `Ctrl-C`.\n1. If you get a `CUDA initialization: CUDA unknown error` after suspend,\n    just stop the container, `rmmod nvidia_uvm`, and restart.\n\n<a name=\"editors\"></a>\n## Editors\n\n<a name=\"vscode\"></a>\n### Visual Studio Code (recommended, WIP)\n\n*We're still writing this guide, let us know of any needed improvements*\n\nThis repo includes VSCode settings that allow for a) editing inside a docker container, b) tests and coverage (on save)\n\n1. Install from https://code.visualstudio.com/\n1. Install [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension.\n1. Open your docker-diffusers-api folder, you'll get a popup in the bottom right that a dev container environment was detected, click \"reload in container\"\n1. Look for the \"( ) Watch\" on status bar and click it so it changes to \"( ) XX Coverage\"\n\n**Live Development**\n\n1. **Run Task** (either Ctrl-Shift-P and \"Run Task\", or in Terminals, the Plus (\"+\") DROPDOWN selector and choose, \"Run Task...\" at the bottom)\n1. Choose **Watching Server**.  Port 8000 will be forwarded.  The server will be reloaded\non every file safe (make sure to give it enough time to fully load before sending another\nrequest, otherwise that request will hang).\n\n<a name=\"testing\"></a>\n## Testing\n\n1. **Unit testing**: exists but is sorely lacking for now.  If you use the\nrecommended editor setup above, it's probably working already.  However:\n\n1. 
**Integration / E2E**: cover most features used in production.\n`pytest -s tests/integration`.\nThe `-s` is optional but streams stdout so you can follow along.\nAdd also `-k test_name` to test a specific test.  E2E tests are LONG but you can\ngreatly reduce subsequent run time by following the steps below for a\n[Local HTTP(S) Caching Proxy](#caching) and [Local S3 Server](#local-s3-server).\n\nDocker-Diffusers-API follows Semantic Versioning.  We follow the\n[conventional commits](https://www.conventionalcommits.org/en/v1.0.0/)\nstandard.\n\n* On a commit to `dev`, if all CI tests pass, a new release is made to `:dev` tag.\n* On a commit to `main`, if all CI tests pass, a new release with appropriate\nmajor / minor / patch is made, based on appropriate tags in the commit history.\n\n<a name=\"buildkit\"></a>\n## Using BuildKit\n\nBuildkit is a docker extension that can really improve build speeds through\ncaching and parallelization.  You can enable and tweak it by adding:\n\n  `DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain`\n\nvars before `docker build` (the `PROGRESS` var shows much more detailed\nbuild logs, which can be useful, but are much more verbose).  This is\nalready all set up in the [build](./build) script.\n\n<a name=\"caching\"></a>\n## Local HTTP(S) Caching Proxy\n\nIf you're only editing e.g. `app.py`, there's no need to worry about caching\nand the docker layers work amazingly.  
But, if you're constantly changing\ninstalled packages (apt, `requirements.txt`), `download.py`, etc, it's VERY\nhelpful to have a local cache:\n\n```bash\n# See all options at https://hub.docker.com/r/gadicc/squid-ssl-zero\n$ docker run -d -p 3128:3128 -p 3129:80 \\\n  --name squid --restart=always \\\n  -v /usr/local/squid:/usr/local/squid \\\n  gadicc/squid-ssl-zero\n```\n\nand then set the docker build args `proxy=1`, and `http_proxy` / `https_proxy`\nwith their respective values.\nThis is already all set up in the [build](./build) script.\n\n**You probably want to fine-tune /usr/local/squid/etc/squid.conf**.\n\nIt will be created after you first run `gadicc/squid-ssl-zero`.  You can then\nstop the container (`docker ps`, `docker stop container_id`), edit the file,\nand re-start (`docker start container_id`).  For now, try something like:\n\n```conf\ncache_dir ufs /usr/local/squid/cache 50000 16 256 # 50GB\nmaximum_object_size 20 GB\nrefresh_pattern .  52034400 50% 52034400 store-stale override-expire ignore-no-cache ignore-no-store ignore-private\n```\n\nbut ideally we can as a community create some rules that don't so\naggressively catch every single request.\n\n<a name=\"local-s3\"></a>\n## Local S3 server\n\nIf you're doing development around the S3 handling, it can be very useful to\nhave a local S3 server, especially due to the large size of models.  You\ncan set one up like this:\n\n```bash\n$ docker run -p 9000:9000 -p 9001:9001 \\\n  -v /usr/local/minio:/data quay.io/minio/minio \\\n  server /data --console-address \":9001\"\n```\n\nNow point a web browser to http://localhost:9001/, login with the default\nroot credentials `minioadmin:minioadmin` and create a bucket and credentials\nfor testing.  
More info at https://hub.docker.com/r/minio/minio/.\n\nTypical policy:\n\n```json\n{\n    \"Version\": \"2012-10-17\",\n    \"Statement\": [\n        {\n            \"Sid\": \"VisualEditor0\",\n            \"Effect\": \"Allow\",\n            \"Action\": [\n                \"s3:PutObject\",\n                \"s3:GetObject\"\n            ],\n            \"Resource\": \"arn:aws:s3:::BUCKET_NAME/*\"\n        }\n    ]\n}\n```\n\nThen set the **build-arg** `AWS_S3_ENDPOINT_URL=\"http://172.17.0.1:9000\"`\nor as appropriate if you've changed the default docker network.\n\n<a name=\"stop-on-suspend\"></a>\n## Stop on Suspend\n\nMaybe it's just me, but frequently I'll have issues when suspending with\nthe container running (I guess it's a CUDA issue), either a freeze on resume,\nor a stuck-forever defunct process.  I found it useful to automatically stop\nthe container / process on suspend.\n\nI'm running ArchLinux and set up a `systemd` suspend hook as described\n[here](https://wiki.archlinux.org/title/Power_management#Sleep_hooks), to\ncall a script, which contains:\n\n```bash\n# Stop a matching docker container\nPID=`docker ps -qf ancestor=gadicc/diffusers-api`\nif [ ! -z $PID ] ; then\n\techo \"Stopping diffusers-api pid $PID\"\n\tdocker stop $PID\nfi\n\n# For a VSCode devcontainer, just kill the watchmedo process.\nPID=`docker ps -qf volume=/home/dragon/root-cache`\nif [ ! -z $PID ] ; then\n\techo \"Stopping watchmedo in container $PID\"\n\tdocker exec $PID /bin/bash -c 'kill `pidof -sx watchmedo`'\nfi\n```\n"
  },
  {
    "path": "Dockerfile",
    "content": "ARG FROM_IMAGE=\"pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime\"\n# ARG FROM_IMAGE=\"gadicc/diffusers-api-base:python3.9-pytorch1.12.1-cuda11.6-xformers\"\n# You only need the -banana variant if you need banana's optimization\n# i.e. not relevant if you're using RUNTIME_DOWNLOADS\n# ARG FROM_IMAGE=\"gadicc/python3.9-pytorch1.12.1-cuda11.6-xformers-banana\"\nFROM ${FROM_IMAGE} as base\nENV FROM_IMAGE=${FROM_IMAGE}\n\n# Note, docker uses HTTP_PROXY and HTTPS_PROXY (uppercase)\n# We purposefully want those managed independently, as we want docker\n# to manage its own cache.  This is just for pip, models, etc.\nARG http_proxy\nARG https_proxy\nRUN if [ -n \"$http_proxy\" ] ; then \\\n    echo quit \\\n    | openssl s_client -proxy $(echo ${https_proxy} | cut -b 8-) -servername google.com -connect google.com:443 -showcerts \\\n    | sed 'H;1h;$!d;x; s/^.*\\(-----BEGIN CERTIFICATE-----.*-----END CERTIFICATE-----\\)\\n---\\nServer certificate.*$/\\1/' \\\n    > /usr/local/share/ca-certificates/squid-self-signed.crt ; \\\n    update-ca-certificates ; \\\n  fi\nARG REQUESTS_CA_BUNDLE=${http_proxy:+/usr/local/share/ca-certificates/squid-self-signed.crt}\n\nARG DEBIAN_FRONTEND=noninteractive\n\nARG TZ=UTC\nRUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone\n\nRUN apt-get update\nRUN apt-get install -yq apt-utils\nRUN apt-get install -yqq git zstd wget curl\n\nFROM base AS patchmatch\nARG USE_PATCHMATCH=0\nWORKDIR /tmp\nCOPY scripts/patchmatch-setup.sh .\nRUN sh patchmatch-setup.sh\n\nFROM base as output\nRUN mkdir /api\nWORKDIR /api\n\n# we use latest pip in base image\n# RUN pip3 install --upgrade pip\n\nADD requirements.txt requirements.txt\nRUN pip install -r requirements.txt\n\n# [Import] Add missing settings / Correct some dummy imports (#5036) - 2023-09-14\nARG DIFFUSERS_VERSION=\"3aa641289c995b3a0ce4ea895a76eb1128eff30c\"\nENV DIFFUSERS_VERSION=${DIFFUSERS_VERSION}\n\nRUN git clone https://github.com/huggingface/diffusers && 
cd diffusers && git checkout ${DIFFUSERS_VERSION}\nWORKDIR /api\nRUN pip install -e diffusers\n\n# Set to true to NOT download model at build time, rather at init / usage.\nARG RUNTIME_DOWNLOADS=1\nENV RUNTIME_DOWNLOADS=${RUNTIME_DOWNLOADS}\n\n# TODO, to dda-bananana\n# ARG PIPELINE=\"StableDiffusionInpaintPipeline\"\nARG PIPELINE=\"ALL\"\nENV PIPELINE=${PIPELINE}\n\n# Deps for RUNNING (not building) earlier options\nARG USE_PATCHMATCH=0\nRUN if [ \"$USE_PATCHMATCH\" = \"1\" ] ; then apt-get install -yqq python3-opencv ; fi\nCOPY --from=patchmatch /tmp/PyPatchMatch PyPatchMatch\n\n# TODO, just include by default, and handle all deps in OUR requirements.txt\nARG USE_DREAMBOOTH=1\nENV USE_DREAMBOOTH=${USE_DREAMBOOTH}\n\nRUN if [ \"$USE_DREAMBOOTH\" = \"1\" ] ; then \\\n    # By specifying the same torch version as conda, it won't download again.\n    # Without this, it will upgrade torch, break xformers, make bigger image.\n    # bitsandbytes==0.40.0.post4 had failed cuda detection on dreambooth test.\n    pip install -r diffusers/examples/dreambooth/requirements.txt ; \\\n  fi\nRUN if [ \"$USE_DREAMBOOTH\" = \"1\" ] ; then apt-get install -yqq git-lfs ; fi\n\nARG USE_REALESRGAN=1\nRUN if [ \"$USE_REALESRGAN\" = \"1\" ] ; then apt-get install -yqq libgl1-mesa-glx libglib2.0-0 ; fi\nRUN if [ \"$USE_REALESRGAN\" = \"1\" ] ; then git clone https://github.com/xinntao/Real-ESRGAN.git ; fi\n# RUN if [ \"$USE_REALESRGAN\" = \"1\" ] ; then pip install numba==0.57.1 chardet ; fi\nRUN if [ \"$USE_REALESRGAN\" = \"1\" ] ; then pip install basicsr==1.4.2 facexlib==0.2.5 gfpgan==1.3.8 ; fi\nRUN if [ \"$USE_REALESRGAN\" = \"1\" ] ; then cd Real-ESRGAN && python3 setup.py develop ; fi\n\nCOPY api/ .\nEXPOSE 8000\n\nARG SAFETENSORS_FAST_GPU=1\nENV SAFETENSORS_FAST_GPU=${SAFETENSORS_FAST_GPU}\n\nCMD python3 -u server.py\n\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Banana, Gadi Cohen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# docker-diffusers-api (\"banana-sd-base\")\n\nDiffusers / Stable Diffusion in docker with a REST API, supporting various models, pipelines & schedulers.  Used by [kiri.art](https://kiri.art/), perfect for local, server & serverless.\n\n[![Docker](https://img.shields.io/docker/v/gadicc/diffusers-api?sort=semver)](https://hub.docker.com/r/gadicc/diffusers-api/tags) [![CircleCI](https://img.shields.io/circleci/build/github/kiri-art/docker-diffusers-api/split)](https://circleci.com/gh/kiri-art/docker-diffusers-api?branch=split) [![semantic-release](https://img.shields.io/badge/%20%20%F0%9F%93%A6%F0%9F%9A%80-semantic--release-e10079.svg)](https://github.com/semantic-release/semantic-release) [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/kiri-art/docker-diffusers-api)\n\nCopyright (c) Gadi Cohen, 2022.  MIT Licensed.\nPlease give credit and link back to this repo if you use it in a public project.\n\n## Features\n\n* Models: stable-diffusion, waifu-diffusion, and easy to add others (e.g. 
jp-sd)\n* Pipelines: txt2img, img2img and inpainting in a single container\n  ([all diffusers official and community pipelines](https://forums.kiri.art/t/all-your-pipelines-are-belong-to-us/83) are wrapped, but untested)\n* All model inputs supported, including setting nsfw filter per request\n* *Permute* base config to multiple forks based on yaml config with vars\n* Optionally send signed event logs / performance data to a REST endpoint / webhook.\n* Can automatically download a checkpoint file and convert to diffusers.\n* S3 support, dreambooth training.\n\nNote: This image was created for [kiri.art](https://kiri.art/).\nEverything is open source but there may be certain request / response\nassumptions.  If anything is unclear, please open an issue.\n\n## Important Notices\n\n* [Official `docker-diffusers-api` Forum](https://forums.kiri.art/c/docker-diffusers-api/16):\n  help, updates, discussion.\n* Subscribe (\"watch\") these forum topics for:\n  * [notable **`main`** branch updates](https://forums.kiri.art/t/official-releases-main-branch/35)\n  * [notable **`dev`** branch updates](https://forums.kiri.art/t/development-releases-dev-branch/53)\n* Always [check the CHANGELOG](./CHANGELOG.md) for important updates when upgrading.\n\n**Official help in our dedicated forum https://forums.kiri.art/c/docker-diffusers-api/16.**\n\n**This README refers to the in-development `dev` branch** and may\nreference features and fixes not yet in the published releases.\n\n**`v1` has not yet been officially released** but has been\nrunning well in production on kiri.art for almost a month.  We'd\nbe grateful for any feedback from early adopters to help make\nthis official.  For more details, see [Upgrading from v0 to\nv1](https://forums.kiri.art/t/wip-upgrading-from-v0-to-v1/116).\nPrevious releases available on the `dev-v0-final` and\n`main-v0-final` branches.\n\n**Currently only NVIDIA / CUDA devices are supported**.  
Tracking\nApple / M1 support in issue\n[#20](https://github.com/kiri-art/docker-diffusers-api/issues/20).\n\n## Installation & Setup:\n\nSetup varies depending on your use case.\n\n1. **To run locally or on a *server*, with runtime downloads:**\n\n    `docker run --gpus all -p 8000:8000 -e HF_AUTH_TOKEN=$HF_AUTH_TOKEN gadicc/diffusers-api`.\n\n    See the [guides for various cloud providers](https://forums.kiri.art/t/running-on-other-cloud-providers/89/7).\n\n1. **To run *serverless*, include the model at build time:**\n\n    1. [docker-diffusers-api-build-download](https://github.com/kiri-art/docker-diffusers-api-build-download) (\n    [banana](https://forums.kiri.art/t/run-diffusers-api-on-banana-dev/103), others)\n    1. [docker-diffusers-api-runpod](https://github.com/kiri-art/docker-diffusers-api-runpod),\n    see the [guide](https://forums.kiri.art/t/run-diffusers-api-on-runpod-io/102)\n\n1. **Building from source**.\n\n    1. Fork / clone this repo.\n    1. `docker build -t gadicc/diffusers-api .`\n    1. 
See [CONTRIBUTING.md](./CONTRIBUTING.md) for more helpful hints.\n\n*Other configurations are possible but these are the most common cases*\n\nEverything is set via docker build-args or environment variables.\n\n## Usage:\n\nSee also [Testing](#testing) below.\n\nThe container expects an `HTTP POST` request to `/`, with a JSON body resembling the following:\n\n```json\n{\n  \"modelInputs\": {\n    \"prompt\": \"Super dog\",\n    \"num_inference_steps\": 50,\n    \"guidance_scale\": 7.5,\n    \"width\": 512,\n    \"height\": 512,\n    \"seed\": 3239022079\n  },\n  \"callInputs\": {\n    // You can leave these out to use the default\n    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n    \"PIPELINE\": \"StableDiffusionPipeline\",\n    \"SCHEDULER\": \"LMSDiscreteScheduler\",\n    \"safety_checker\": true,\n  },\n}\n```\n\nIt's important to remember that `docker-diffusers-api` is primarily a wrapper\naround HuggingFace's\n[diffusers](https://huggingface.co/docs/diffusers/index) library.\n**Basic familiarity with `diffusers` is indespensible for a good experience\nwith `docker-diffusers-api`.**  Explaining some of the options above:\n\n* **modelInputs** - for the most part - are passed directly to the selected\ndiffusers pipeline unchanged.  So, for the default `StableDiffusionPipeline`,\nyou can see all options in the relevant pipeline docs for its\n[`__call__`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__) method.  The main exceptions are:\n\n    * Only valid JSON values can be given (strings, numbers, etc)\n    * **seed**, a number, is transformed into a `generator`.\n    * **images** are converted to / from base64 encoded strings.\n\n* **callInputs** affect which model, pipeline, scheduler and other lower\nlevel options are used to construct the final pipeline.  
Notably:\n\n    * **`SCHEDULER`**: any scheduler included in diffusers should work out\n    of the box, provided it can be loaded with its default config and without\n    requiring any other explicit arguments at init time.  In any event,\n    the following schedulers are the most common and most well tested:\n    `DPMSolverMultistepScheduler` (fast!  only needs 20 steps!),\n    `LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`,\n    `EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`.\n\n    * **`PIPELINE`**: the most common are\n    [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img),\n    [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img),\n    [`StableDiffusionInpaintPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint), and the community\n    [`lpw_stable_diffusion`](https://forums.kiri.art/t/lpw-stable-diffusion-pipeline-longer-prompts-prompt-weights/82)\n    which allows for long prompts (more than 77 tokens) and prompt weights\n    (things like `((big eyes))`, `(red hair:1.2)`, etc), and accepts a\n    `custom_pipeline_method` callInput with values `text2img` (\"text\", not \"txt\"),\n    `img2img` and `inpaint`.  See these links for all the possible `modelInputs`\n    that can be passed to the pipeline's `__call__` method.\n\n    * **`MODEL_URL`** (optional) can be used to retrieve the model from\n    locations other than HuggingFace, e.g. an `HTTP` server, S3-compatible\n    storage, etc.  
For more info, see the\n    [storage docs](https://github.com/kiri-art/docker-diffusers-api/blob/dev/docs/storage.md)\n    and\n    [this post](https://forums.kiri.art/t/safetensors-our-own-optimization-faster-model-init/98)\n    for info on how to use and store optimized models from your own cloud.\n\n<a name=\"testing\"></a>\n## Examples and testing\n\nThere are also very basic examples in [test.py](./test.py), which you can view\nand call `python test.py` if the container is already running on port 8000.\nYou can also specify a specific test, change some options, and run against a\ndeployed banana image:\n\n```bash\n$ python test.py\nUsage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]\n\n# Run against http://localhost:8000/ (Nvidia Quadro RTX 5000)\n$ python test.py txt2img\nRunning test: txt2img\nRequest took 5.9s (init: 3.2s, inference: 5.9s)\nSaved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png\n\n# Run against deployed banana image (Nvidia A100)\n$ export BANANA_API_KEY=XXX\n$ BANANA_MODEL_KEY=XXX python3 test.py --banana txt2img\nRunning test: txt2img\nRequest took 19.4s (init: 2.5s, inference: 3.5s)\nSaved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png\n\n# Note that 2nd runs are much faster (ignore init, that isn't run again)\nRequest took 3.0s (init: 2.4s, inference: 2.1s)\n```\n\nThe best example of course is https://kiri.art/ and it's\n[source code](https://github.com/kiri-art/stable-diffusion-react-nextjs-mui-pwa).\n\n## Help on [Official Forums](https://forums.kiri.art/c/docker-diffusers-api/16).\n\n## Adding other Models\n\nYou have two options.\n\n1. For a diffusers model, simply set `MODEL_ID` build-var / call-arg to the name\n  of the model hosted on HuggingFace, and it will be downloaded automatically at\n  build time.\n\n1. 
For a non-diffusers model, simply set the `CHECKPOINT_URL` build-var / call-arg\n  to the URL of a `.ckpt` file, which will be downloaded and converted to the diffusers\n  format automatically at build time.  `CHECKPOINT_CONFIG_URL` can also be set.\n\n## Troubleshooting\n\n* **403 Client Error: Forbidden for url**\n\n  Make sure you've accepted the license on the model card of the HuggingFace model\n  specified in `MODEL_ID`, and that you correctly passed `HF_AUTH_TOKEN` to the\n  container.\n\n## Event logs / web hooks / performance data\n\nSet `SEND_URL` (and optionally `SIGN_KEY`) environment variable(s) to send\nevent and timing data on `init`, `inference` and other start and end events.\nThis can either be used to log performance data, or for webhooks on event\nstart / finish.\n\nThe timing data is now returned in the response payload too, like this:\n`{ $timings: { init: timeInMs, inference: timeInMs } }`, with any other\nevents (such as `training`, `upload`, etc).\n\nYou can go to https://webhook.site/ and use the provided \"unique URL\"\nas your `SEND_URL` to see how it works, if you don't have your own\nREST endpoint (yet).\n\nIf `SIGN_KEY` is used, you can verify the signature like this (TypeScript):\n\n```ts\nimport crypto from \"crypto\";\n\nasync function handler(req: NextApiRequest, res: NextApiResponse) {\n  const data = req.body;\n\n  const containerSig = data.sig as string;\n  delete data.sig;\n\n  const ourSig = crypto\n    .createHash(\"md5\")\n    .update(JSON.stringify(data) + process.env.SIGN_KEY)\n    .digest(\"hex\");\n\n  const signatureIsValid = containerSig === ourSig;\n}\n```\n\nIf you send a callInput called `startRequestId`, it will get sent\nback as part of the send payload in most cases.\n\nYou can also set callInputs `SEND_URL` and `SIGN_KEY` to\nset or override these values on a per-request basis.\n\n## Acknowledgements\n\n* The container image is originally based on\n  
https://github.com/bananaml/serverless-template-stable-diffusion.\n\n* [CompVis](https://github.com/CompVis),\n  [Stability AI](https://stability.ai/),\n  [LAION](https://laion.ai/)\n  and [RunwayML](https://runwayml.com/)\n  for their incredible time, work and efforts in creating Stable Diffusion,\n  and no less so, their decision to release it publicly with an open source\n  license.\n\n* [HuggingFace](https://huggingface.co/) - for their passion and inspiration\n  for making machine learning more accessible to developers, and in particular,\n  their [Diffusers](https://github.com/huggingface/diffusers) library.\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "api/app.py",
    "content": "import asyncio\nfrom sched import scheduler\nimport torch\n\nfrom torch import autocast\nfrom diffusers import __version__\nimport base64\nfrom io import BytesIO\nimport PIL\nimport json\nfrom loadModel import loadModel\nfrom send import send, getTimings, clearSession\nfrom status import status\nimport os\nimport numpy as np\nimport skimage\nimport skimage.measure\nfrom getScheduler import getScheduler, SCHEDULERS\nfrom getPipeline import (\n    getPipelineClass,\n    getPipelineForModel,\n    listAvailablePipelines,\n    clearPipelines,\n)\nimport re\nimport requests\nfrom download import download_model, normalize_model_id\nimport traceback\nfrom precision import MODEL_REVISION, MODEL_PRECISION\nfrom device import device, device_id, device_name\nfrom utils import Storage\nfrom hashlib import sha256\nfrom threading import Timer\nimport extras\nimport jxlpy\nfrom jxlpy import JXLImagePlugin\n\n\nfrom diffusers import (\n    StableDiffusionXLPipeline,\n    StableDiffusionXLImg2ImgPipeline,\n    StableDiffusionXLInpaintPipeline,\n    pipelines as diffusers_pipelines,\n    AutoencoderTiny,\n    AutoencoderKL,\n)\n\nfrom lib.textual_inversions import handle_textual_inversions\nfrom lib.prompts import prepare_prompts\nfrom lib.vars import (\n    RUNTIME_DOWNLOADS,\n    USE_DREAMBOOTH,\n    MODEL_ID,\n    PIPELINE,\n    HF_AUTH_TOKEN,\n    HOME,\n    MODELS_DIR,\n)\n\nif USE_DREAMBOOTH:\n    from train_dreambooth import TrainDreamBooth\nprint(os.environ.get(\"USE_PATCHMATCH\"))\nif os.environ.get(\"USE_PATCHMATCH\") == \"1\":\n    from PyPatchMatch import patch_match\n\ntorch.set_grad_enabled(False)\nalways_normalize_model_id = None\n\ntiny_vae = None\n\n\n# still working on this, not in use yet.\ndef tinyVae(origVae: AutoencoderKL):\n    global tiny_vae\n    if not tiny_vae:\n        tiny_vae = AutoencoderTiny.from_pretrained(\n            \"madebyollin/taesd\",\n            torch_dtype=torch.float16,\n            in_channels=origVae.config.in_channels,\n 
           out_channels=origVae.config.out_channels,\n            act_fn=origVae.config.act_fn,\n            latent_channels=origVae.config.latent_channels,\n            scaling_factor=origVae.config.scaling_factor,\n            force_upcast=origVae.config.force_upcast,\n        )\n        tiny_vae.to(\"cuda\")\n\n    return tiny_vae\n\n\n# Init is ran on server startup\n# Load your model to GPU as a global variable here using the variable name \"model\"\ndef init():\n    global model  # needed for bananna optimizations\n    global always_normalize_model_id\n\n    asyncio.run(\n        send(\n            \"init\",\n            \"start\",\n            {\n                \"device\": device_name,\n                \"hostname\": os.getenv(\"HOSTNAME\"),\n                \"model_id\": MODEL_ID,\n                \"diffusers\": __version__,\n            },\n        )\n    )\n\n    if MODEL_ID == \"ALL\" or RUNTIME_DOWNLOADS:\n        global last_model_id\n        last_model_id = None\n\n    if not RUNTIME_DOWNLOADS:\n        normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)\n        model_dir = os.path.join(MODELS_DIR, normalized_model_id)\n        if os.path.isdir(model_dir):\n            always_normalize_model_id = model_dir\n        else:\n            normalized_model_id = MODEL_ID\n\n        model = loadModel(\n            model_id=always_normalize_model_id or MODEL_ID,\n            load=True,\n            precision=MODEL_PRECISION,\n            revision=MODEL_REVISION,\n        )\n    else:\n        model = None\n\n    asyncio.run(send(\"init\", \"done\"))\n\n\ndef decodeBase64Image(imageStr: str, name: str) -> PIL.Image:\n    image = PIL.Image.open(BytesIO(base64.decodebytes(bytes(imageStr, \"utf-8\"))))\n    print(f'Decoded image \"{name}\": {image.format} {image.width}x{image.height}')\n    return image\n\n\ndef getFromUrl(url: str, name: str) -> PIL.Image:\n    response = requests.get(url)\n    image = PIL.Image.open(BytesIO(response.content))\n   
 print(f'Decoded image \"{name}\": {image.format} {image.width}x{image.height}')\n    return image\n\n\ndef truncateInputs(inputs: dict):\n    clone = inputs.copy()\n    if \"modelInputs\" in clone:\n        modelInputs = clone[\"modelInputs\"] = clone[\"modelInputs\"].copy()\n        for item in [\"init_image\", \"mask_image\", \"image\", \"input_image\"]:\n            if item in modelInputs:\n                modelInputs[item] = modelInputs[item][0:6] + \"...\"\n        if \"instance_images\" in modelInputs:\n            modelInputs[\"instance_images\"] = list(\n                map(lambda str: str[0:6] + \"...\", modelInputs[\"instance_images\"])\n            )\n    return clone\n\n\n# last_xformers_memory_efficient_attention = {}\nlast_attn_procs = None\nlast_lora_weights = None\ncross_attention_kwargs = None\n\n\n# Inference is ran for every server call\n# Reference your preloaded global model variable here.\nasync def inference(all_inputs: dict, response) -> dict:\n    global model\n    global pipelines\n    global last_model_id\n    global schedulers\n    # global last_xformers_memory_efficient_attention\n    global always_normalize_model_id\n    global last_attn_procs\n    global last_lora_weights\n    global cross_attention_kwargs\n\n    clearSession()\n\n    print(json.dumps(truncateInputs(all_inputs), indent=2))\n    model_inputs = all_inputs.get(\"modelInputs\", None)\n    call_inputs = all_inputs.get(\"callInputs\", None)\n    result = {\"$meta\": {}}\n\n    send_opts = {}\n    if call_inputs.get(\"SEND_URL\", None):\n        send_opts.update({\"SEND_URL\": call_inputs.get(\"SEND_URL\")})\n    if call_inputs.get(\"SIGN_KEY\", None):\n        send_opts.update({\"SIGN_KEY\": call_inputs.get(\"SIGN_KEY\")})\n    if response:\n        send_opts.update({\"response\": response})\n\n        async def sendStatusAsync():\n            await response.send(json.dumps(status.get()) + \"\\n\")\n\n        def sendStatus():\n            try:\n                
asyncio.run(sendStatusAsync())\n                Timer(1.0, sendStatus).start()\n            except:\n                pass\n\n        Timer(1.0, sendStatus).start()\n\n    if model_inputs == None or call_inputs == None:\n        return {\n            \"$error\": {\n                \"code\": \"INVALID_INPUTS\",\n                \"message\": \"Expecting on object like { modelInputs: {}, callInputs: {} } but got \"\n                + json.dumps(all_inputs),\n            }\n        }\n\n    startRequestId = call_inputs.get(\"startRequestId\", None)\n\n    use_extra = call_inputs.get(\"use_extra\", None)\n    if use_extra:\n        extra = getattr(extras, use_extra, None)\n        if not extra:\n            return {\n                \"$error\": {\n                    \"code\": \"NO_SUCH_EXTRA\",\n                    \"message\": 'Requested \"'\n                    + use_extra\n                    + '\", available: \"'\n                    + '\", \"'.join(extras.keys())\n                    + '\"',\n                }\n            }\n        return await extra(\n            model_inputs,\n            call_inputs,\n            send_opts=send_opts,\n            startRequestId=startRequestId,\n        )\n\n    model_id = call_inputs.get(\"MODEL_ID\", None)\n    if not model_id:\n        if not MODEL_ID:\n            return {\n                \"$error\": {\n                    \"code\": \"NO_MODEL_ID\",\n                    \"message\": \"No callInputs.MODEL_ID specified, nor was MODEL_ID env var set.\",\n                }\n            }\n        model_id = MODEL_ID\n        result[\"$meta\"].update({\"MODEL_ID\": MODEL_ID})\n    normalized_model_id = model_id\n\n    if RUNTIME_DOWNLOADS:\n        hf_model_id = call_inputs.get(\"HF_MODEL_ID\", None)\n        model_revision = call_inputs.get(\"MODEL_REVISION\", None)\n        model_precision = call_inputs.get(\"MODEL_PRECISION\", None)\n        checkpoint_url = call_inputs.get(\"CHECKPOINT_URL\", None)\n        
checkpoint_config_url = call_inputs.get(\"CHECKPOINT_CONFIG_URL\", None)\n        normalized_model_id = normalize_model_id(model_id, model_revision)\n        model_dir = os.path.join(MODELS_DIR, normalized_model_id)\n        pipeline_name = call_inputs.get(\"PIPELINE\", None)\n        if pipeline_name:\n            pipeline_class = getPipelineClass(pipeline_name)\n        if last_model_id != normalized_model_id:\n            # if not downloaded_models.get(normalized_model_id, None):\n            if not os.path.isdir(model_dir):\n                model_url = call_inputs.get(\"MODEL_URL\", None)\n                if not model_url:\n                    # return {\n                    #     \"$error\": {\n                    #         \"code\": \"NO_MODEL_URL\",\n                    #         \"message\": \"Currently RUNTIME_DOWNOADS requires a MODEL_URL callInput\",\n                    #     }\n                    # }\n                    normalized_model_id = hf_model_id or model_id\n                await download_model(\n                    model_id=model_id,\n                    model_url=model_url,\n                    model_revision=model_revision,\n                    checkpoint_url=checkpoint_url,\n                    checkpoint_config_url=checkpoint_config_url,\n                    hf_model_id=hf_model_id,\n                    model_precision=model_precision,\n                    send_opts=send_opts,\n                    pipeline_class=pipeline_class if pipeline_name else None,\n                )\n                # downloaded_models.update({normalized_model_id: True})\n            clearPipelines()\n            cross_attention_kwargs = None\n            if model:\n                model.to(\"cpu\")  # Necessary to avoid a memory leak\n            await send(\n                \"loadModel\", \"start\", {\"startRequestId\": startRequestId}, send_opts\n            )\n            model = await asyncio.to_thread(\n                loadModel,\n                
model_id=normalized_model_id,\n                load=True,\n                precision=model_precision,\n                revision=model_revision,\n                send_opts=send_opts,\n                pipeline_class=pipeline_class if pipeline_name else None,\n            )\n            await send(\n                \"loadModel\", \"done\", {\"startRequestId\": startRequestId}, send_opts\n            )\n            last_model_id = normalized_model_id\n            last_attn_procs = None\n            last_lora_weights = None\n    else:\n        if always_normalize_model_id:\n            normalized_model_id = always_normalize_model_id\n        print(\n            {\n                \"always_normalize_model_id\": always_normalize_model_id,\n                \"normalized_model_id\": normalized_model_id,\n            }\n        )\n\n    if MODEL_ID == \"ALL\":\n        if last_model_id != normalized_model_id:\n            clearPipelines()\n            cross_attention_kwargs = None\n            model = loadModel(normalized_model_id, send_opts=send_opts)\n            last_model_id = normalized_model_id\n    else:\n        if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:\n            return {\n                \"$error\": {\n                    \"code\": \"MODEL_MISMATCH\",\n                    \"message\": f'Model \"{model_id}\" not available on this container which hosts \"{MODEL_ID}\"',\n                    \"requested\": model_id,\n                    \"available\": MODEL_ID,\n                }\n            }\n\n    if PIPELINE == \"ALL\":\n        pipeline_name = call_inputs.get(\"PIPELINE\", None)\n        if not pipeline_name:\n            pipeline_name = \"AutoPipelineForText2Image\"\n            result[\"$meta\"].update({\"PIPELINE\": pipeline_name})\n\n        pipeline = getPipelineForModel(\n            pipeline_name,\n            model,\n            normalized_model_id,\n            model_revision=model_revision if RUNTIME_DOWNLOADS else MODEL_REVISION,\n            
model_precision=model_precision if RUNTIME_DOWNLOADS else MODEL_PRECISION,\n        )\n        if not pipeline:\n            return {\n                \"$error\": {\n                    \"code\": \"NO_SUCH_PIPELINE\",\n                    \"message\": f'\"{pipeline_name}\" is not an official nor community Diffusers pipelines',\n                    \"requested\": pipeline_name,\n                    \"available\": listAvailablePipelines(),\n                }\n            }\n    else:\n        pipeline = model\n\n    scheduler_name = call_inputs.get(\"SCHEDULER\", None)\n    if not scheduler_name:\n        scheduler_name = \"DPMSolverMultistepScheduler\"\n        result[\"$meta\"].update({\"SCHEDULER\": scheduler_name})\n\n    pipeline.scheduler = getScheduler(normalized_model_id, scheduler_name)\n    if pipeline.scheduler == None:\n        return {\n            \"$error\": {\n                \"code\": \"INVALID_SCHEDULER\",\n                \"message\": \"\",\n                \"requeted\": call_inputs.get(\"SCHEDULER\", None),\n                \"available\": \", \".join(SCHEDULERS),\n            }\n        }\n\n    safety_checker = call_inputs.get(\"safety_checker\", True)\n    pipeline.safety_checker = (\n        model.safety_checker\n        if safety_checker and hasattr(model, \"safety_checker\")\n        else None\n    )\n    is_url = call_inputs.get(\"is_url\", False)\n    image_decoder = getFromUrl if is_url else decodeBase64Image\n\n    textual_inversions = call_inputs.get(\"textual_inversions\", [])\n    await handle_textual_inversions(textual_inversions, model, status=status)\n\n    # Better to use new lora_weights in next section\n    attn_procs = call_inputs.get(\"attn_procs\", None)\n    if attn_procs is not last_attn_procs:\n        if attn_procs:\n            raise Exception(\n                \"[REMOVED] Using `attn_procs` for LoRAs is no longer supported. 
\"\n                + \"Please use `lora_weights` instead.\"\n            )\n        last_attn_procs = attn_procs\n    #     if attn_procs:\n    #         storage = Storage(attn_procs, no_raise=True)\n    #         if storage:\n    #             hash = sha256(attn_procs.encode(\"utf-8\")).hexdigest()\n    #             attn_procs_from_safetensors = call_inputs.get(\n    #                 \"attn_procs_from_safetensors\", None\n    #             )\n    #             fname = storage.url.split(\"/\").pop()\n    #             if attn_procs_from_safetensors and not re.match(\n    #                 r\".safetensors\", attn_procs\n    #             ):\n    #                 fname += \".safetensors\"\n    #             if True:\n    #                 # TODO, way to specify explicit name\n    #                 path = os.path.join(\n    #                     MODELS_DIR, \"attn_proc--url_\" + hash[:7] + \"--\" + fname\n    #                 )\n    #             attn_procs = path\n    #             if not os.path.exists(path):\n    #                 storage.download_and_extract(path)\n    #         print(\"Load attn_procs \" + attn_procs)\n    #         # Workaround https://github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119\n    #         if storage and not re.search(r\".safetensors\", attn_procs):\n    #             attn_procs = torch.load(attn_procs, map_location=\"cpu\")\n    #         pipeline.unet.load_attn_procs(attn_procs)\n    #     else:\n    #         print(\"Clearing attn procs\")\n    #         pipeline.unet.set_attn_processor(CrossAttnProcessor())\n\n    # Currently we only support a single string, but we should allow\n    # and array too in anticipation of multi-LoRA support in diffusers\n    # tracked at https://github.com/huggingface/diffusers/issues/2613.\n    lora_weights = call_inputs.get(\"lora_weights\", None)\n    lora_weights_joined = json.dumps(lora_weights)\n    if last_lora_weights != lora_weights_joined:\n        if last_lora_weights != 
None and last_lora_weights != \"[]\":\n            print(\"Unloading previous LoRA weights\")\n            pipeline.unload_lora_weights()\n\n        last_lora_weights = lora_weights_joined\n        cross_attention_kwargs = {}\n\n        if type(lora_weights) is not list:\n            lora_weights = [lora_weights] if lora_weights else []\n\n        if len(lora_weights) > 0:\n            for weights in lora_weights:\n                storage = Storage(weights, no_raise=True, status=status)\n                if storage:\n                    storage_query_fname = storage.query.get(\"fname\")\n                    storage_query_scale = (\n                        float(storage.query.get(\"scale\")[0])\n                        if storage.query.get(\"scale\")\n                        else 1\n                    )\n                    cross_attention_kwargs.update({\"scale\": storage_query_scale})\n                    # https://github.com/damian0815/compel/issues/42#issuecomment-1656989385\n                    pipeline._lora_scale = storage_query_scale\n                    if storage_query_fname:\n                        fname = storage_query_fname[0]\n                    else:\n                        hash = sha256(weights.encode(\"utf-8\")).hexdigest()\n                        fname = \"url_\" + hash[:7] + \"--\" + storage.url.split(\"/\").pop()\n                    cache_fname = \"lora_weights--\" + fname\n                    path = os.path.join(MODELS_DIR, cache_fname)\n                    if not os.path.exists(path):\n                        await asyncio.to_thread(storage.download_file, path)\n                    print(\"Load lora_weights `\" + weights + \"` from `\" + path + \"`\")\n                    pipeline.load_lora_weights(\n                        MODELS_DIR, weight_name=cache_fname, local_files_only=True\n                    )\n                else:\n                    print(\"Loading from huggingface not supported yet: \" + weights)\n                    # 
maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?\n                    # lora_model_id = \"sayakpaul/civitai-light-shadow-lora\"\n                    # lora_filename = \"light_and_shadow.safetensors\"\n                    # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)\n    else:\n        print(\"No changes to LoRAs since last call\")\n\n    # TODO, generalize\n    mi_cross_attention_kwargs = model_inputs.get(\"cross_attention_kwargs\", None)\n    if mi_cross_attention_kwargs:\n        model_inputs.pop(\"cross_attention_kwargs\")\n        if isinstance(mi_cross_attention_kwargs, str):\n            if not cross_attention_kwargs:\n                cross_attention_kwargs = {}\n            cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs))\n        elif type(mi_cross_attention_kwargs) == dict:\n            if not cross_attention_kwargs:\n                cross_attention_kwargs = {}\n            cross_attention_kwargs.update(mi_cross_attention_kwargs)\n        else:\n            return {\n                \"$error\": {\n                    \"code\": \"INVALID_CROSS_ATTENTION_KWARGS\",\n                    \"message\": \"`cross_attention_kwargs` should be a dict or json string\",\n                }\n            }\n\n    print({\"cross_attention_kwargs\": cross_attention_kwargs})\n    if cross_attention_kwargs:\n        model_inputs.update({\"cross_attention_kwargs\": cross_attention_kwargs})\n\n    # Parse out your arguments\n    # prompt = model_inputs.get(\"prompt\", None)\n    # if prompt == None:\n    #     return {\"message\": \"No prompt provided\"}\n    #\n    #   height = model_inputs.get(\"height\", 512)\n    #  width = model_inputs.get(\"width\", 512)\n    # num_inference_steps = model_inputs.get(\"num_inference_steps\", 50)\n    # guidance_scale = model_inputs.get(\"guidance_scale\", 7.5)\n    # seed = model_inputs.get(\"seed\", None)\n    #   strength = model_inputs.get(\"strength\", 0.75)\n\n    if 
\"init_image\" in model_inputs:\n        model_inputs[\"init_image\"] = image_decoder(\n            model_inputs.get(\"init_image\"), \"init_image\"\n        )\n\n    if \"image\" in model_inputs:\n        model_inputs[\"image\"] = image_decoder(model_inputs.get(\"image\"), \"image\")\n\n    if \"mask_image\" in model_inputs:\n        model_inputs[\"mask_image\"] = image_decoder(\n            model_inputs.get(\"mask_image\"), \"mask_image\"\n        )\n\n    if \"instance_images\" in model_inputs:\n        model_inputs[\"instance_images\"] = list(\n            map(\n                lambda str: image_decoder(str, \"instance_image\"),\n                model_inputs[\"instance_images\"],\n            )\n        )\n\n    await send(\"inference\", \"start\", {\"startRequestId\": startRequestId}, send_opts)\n\n    # Run patchmatch for inpainting\n    if call_inputs.get(\"FILL_MODE\", None) == \"patchmatch\":\n        sel_buffer = np.array(model_inputs.get(\"init_image\"))\n        img = sel_buffer[:, :, 0:3]\n        mask = sel_buffer[:, :, -1]\n        img = patch_match.inpaint(img, mask=255 - mask, patch_size=3)\n        model_inputs[\"init_image\"] = PIL.Image.fromarray(img)\n        mask = 255 - mask\n        mask = skimage.measure.block_reduce(mask, (8, 8), np.max)\n        mask = mask.repeat(8, axis=0).repeat(8, axis=1)\n        model_inputs[\"mask_image\"] = PIL.Image.fromarray(mask)\n\n    # Turning on takes 3ms and turning off 1ms... 
don't worry, I've got your back :)\n    # x_m_e_a = call_inputs.get(\"xformers_memory_efficient_attention\", True)\n    # last_x_m_e_a = last_xformers_memory_efficient_attention.get(pipeline, None)\n    # if x_m_e_a != last_x_m_e_a:\n    #     if x_m_e_a == True:\n    #         print(\"pipeline.enable_xformers_memory_efficient_attention()\")\n    #         pipeline.enable_xformers_memory_efficient_attention()  # default on\n    #     elif x_m_e_a == False:\n    #         print(\"pipeline.disable_xformers_memory_efficient_attention()\")\n    #         pipeline.disable_xformers_memory_efficient_attention()\n    #     else:\n    #         return {\n    #             \"$error\": {\n    #                 \"code\": \"INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE\",\n    #                 \"message\": f\"x_m_e_a expects True or False, not: {x_m_e_a}\",\n    #                 \"requested\": x_m_e_a,\n    #                 \"available\": [True, False],\n    #             }\n    #         }\n    #     last_xformers_memory_efficient_attention.update({pipeline: x_m_e_a})\n\n    # Run the model\n    # with autocast(device_id):\n    # image = pipeline(**model_inputs).images[0]\n\n    if call_inputs.get(\"train\", None) == \"dreambooth\":\n        if not USE_DREAMBOOTH:\n            return {\n                \"$error\": {\n                    \"code\": \"TRAIN_DREAMBOOTH_NOT_AVAILABLE\",\n                    \"message\": 'Called with callInput { train: \"dreambooth\" } but built with USE_DREAMBOOTH=0',\n                }\n            }\n\n        if RUNTIME_DOWNLOADS:\n            if os.path.isdir(model_dir):\n                normalized_model_id = model_dir\n\n        torch.set_grad_enabled(True)\n        result = result | await asyncio.to_thread(\n            TrainDreamBooth,\n            normalized_model_id,\n            pipeline,\n            model_inputs,\n            call_inputs,\n            send_opts=send_opts,\n        )\n        torch.set_grad_enabled(False)\n        
await send(\"inference\", \"done\", {\"startRequestId\": startRequestId}, send_opts)\n        result.update({\"$timings\": getTimings()})\n        return result\n\n    # Do this after dreambooth as dreambooth accepts a seed int directly.\n    seed = model_inputs.get(\"seed\", None)\n    if seed == None:\n        generator = torch.Generator(device=device)\n        generator.seed()\n    else:\n        generator = torch.Generator(device=device).manual_seed(seed)\n        del model_inputs[\"seed\"]\n\n    model_inputs.update({\"generator\": generator})\n\n    callback = None\n    if model_inputs.get(\"callback_steps\", None):\n\n        def callback(step: int, timestep: int, latents: torch.FloatTensor):\n            asyncio.run(\n                send(\n                    \"inference\",\n                    \"progress\",\n                    {\"startRequestId\": startRequestId, \"step\": step},\n                    send_opts,\n                )\n            )\n\n    else:\n        vae = pipeline.vae\n        # vae = tinyVae(vae)\n        scaling_factor = vae.config.scaling_factor\n        image_processor = pipeline.image_processor\n\n        def callback(step: int, timestep: int, latents: torch.FloatTensor):\n            status.update(\n                \"inference\", step / model_inputs.get(\"num_inference_steps\", 50)\n            )\n\n            # with torch.no_grad():\n            #     image = vae.decode(latents / scaling_factor, return_dict=False)[0]\n            #     image = image_processor.postprocess(image, output_type=\"pil\")[0]\n            #     image.save(f\"step_{step}_img0.png\")\n\n    is_sdxl = (\n        isinstance(model, StableDiffusionXLPipeline)\n        or isinstance(model, StableDiffusionXLImg2ImgPipeline)\n        or isinstance(model, StableDiffusionXLInpaintPipeline)\n    )\n\n    with torch.inference_mode():\n        custom_pipeline_method = call_inputs.get(\"custom_pipeline_method\", None)\n        print(\n            {\n                
\"callback\": callback,\n                \"**model_inputs\": model_inputs,\n            },\n        )\n\n        if call_inputs.get(\"compel_prompts\", False):\n            prepare_prompts(pipeline, model_inputs, is_sdxl)\n\n        try:\n            async_pipeline = asyncio.to_thread(\n                getattr(pipeline, custom_pipeline_method)\n                if custom_pipeline_method\n                else pipeline,\n                callback=callback,\n                **model_inputs,\n            )\n            # if call_inputs.get(\"PIPELINE\") != \"StableDiffusionPipeline\":\n            #    # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1\n            #    # still broken in 0.5.1\n            #    with autocast(device_id):\n            #        images = (await async_pipeline).images\n            # else:\n            pipeResult = await async_pipeline\n            images = pipeResult.images\n\n        except Exception as err:\n            return {\n                \"$error\": {\n                    \"code\": \"PIPELINE_ERROR\",\n                    \"name\": type(err).__name__,\n                    \"message\": str(err),\n                    \"stack\": traceback.format_exc(),\n                }\n            }\n\n    images_base64 = []\n    image_format = call_inputs.get(\"image_format\", \"PNG\")\n    image_opts = (\n        {\"lossless\": True} if image_format == \"PNG\" or image_format == \"WEBP\" else {}\n    )\n    for image in images:\n        buffered = BytesIO()\n        image.save(buffered, format=image_format, **image_opts)\n        images_base64.append(base64.b64encode(buffered.getvalue()).decode(\"utf-8\"))\n\n    await send(\"inference\", \"done\", {\"startRequestId\": startRequestId}, send_opts)\n\n    # Return the results as a dictionary\n    if len(images_base64) > 1:\n        result = result | {\"images_base64\": images_base64}\n    else:\n        result = result | {\"image_base64\": images_base64[0]}\n\n    nsfw_content_detected = 
pipeResult.get(\"nsfw_content_detected\", None)\n    if nsfw_content_detected:\n        result = result | {\"nsfw_content_detected\": nsfw_content_detected}\n\n    # TODO, move and generalize in device.py\n    mem_usage = 0\n    if torch.cuda.is_available():\n        mem_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()\n\n    result = result | {\"$timings\": getTimings(), \"$mem_usage\": mem_usage}\n\n    return result\n"
  },
  {
    "path": "api/convert_to_diffusers.py",
    "content": "import os\nimport requests\nimport subprocess\nimport torch\nimport json\nfrom diffusers.pipelines.stable_diffusion.convert_from_ckpt import (\n    download_from_original_stable_diffusion_ckpt,\n)\nfrom diffusers.pipelines.stable_diffusion import (\n    StableDiffusionInpaintPipeline,\n)\nfrom utils import Storage\nfrom device import device_id\n\nMODEL_ID = os.environ.get(\"MODEL_ID\", None)\nCHECKPOINT_DIR = \"/root/.cache/checkpoints\"\nCHECKPOINT_URL = os.environ.get(\"CHECKPOINT_URL\", None)\nCHECKPOINT_CONFIG_URL = os.environ.get(\"CHECKPOINT_CONFIG_URL\", None)\nCHECKPOINT_ARGS = os.environ.get(\"CHECKPOINT_ARGS\", None)\n# _CONVERT_SPECIAL = os.environ.get(\"_CONVERT_SPECIAL\", None)\n\n\ndef main(\n    model_id: str,\n    checkpoint_url: str,\n    checkpoint_config_url: str,\n    checkpoint_args: dict = {},\n    path=None,\n):\n    if not path:\n        fname = checkpoint_url.split(\"/\").pop()\n        path = os.path.join(CHECKPOINT_DIR, fname)\n\n    if checkpoint_config_url and checkpoint_config_url != \"\":\n        storage = Storage(checkpoint_config_url)\n        configPath = CHECKPOINT_DIR + \"/\" + path + \"_config.yaml\"\n        print(f\"Downloading {checkpoint_config_url} to {configPath}...\")\n        storage.download_file(configPath)\n\n    # specialSrc = \"https://raw.githubusercontent.com/hafriedlander/diffusers/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py\"\n    # specialPath = CHECKPOINT_DIR + \"/\" + \"convert_special.py\"\n    # if _CONVERT_SPECIAL:\n    #     storage = Storage(specialSrc)\n    #     print(f\"Downloading {specialSrc} to {specialPath}\")\n    #     storage.download_file(specialPath)\n\n    # scriptPath = (\n    #     # specialPath\n    #     # if _CONVERT_SPECIAL\n    #     # else\n    #     \"./diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py\"\n    # )\n\n    print(\"Converting \" + path + \" to diffusers model \" + model_id + \"...\", flush=True)\n\n    
# These are now in main requirements.txt.\n    # subprocess.run(\n    #     [\"pip\", \"install\", \"omegaconf\", \"pytorch_lightning\", \"tensorboard\"], check=True\n    # )\n    # Diffusers now uses requests instead, yay!\n    # subprocess.run([\"apt-get\", \"install\", \"-y\", \"wget\"], check=True)\n\n    # We can now specify this ourselves and don't need to modify the script.\n    # if device_id == \"cpu\":\n    #     subprocess.run(\n    #         [\n    #             \"sed\",\n    #             \"-i\",\n    #             # Force loading into CPU\n    #             \"s/torch.load(args.checkpoint_path)/torch.load(args.checkpoint_path, map_location=torch.device('cpu'))/\",\n    #             scriptPath,\n    #         ]\n    #     )\n    # # Nice to check but also there seems to be a race condition here which\n    # # needs further investigation.  Python docs are clear that subprocess.run()\n    # # will \"Wait for command to complete, then return a CompletedProcess instance.\"\n    # # But it really seems as though without the grep in the middle, the script is\n    # # run before sed completes, or maybe there's some FS level caching gotchas.\n    # subprocess.run(\n    #     [\n    #         \"grep\",\n    #         \"torch.load\",\n    #         scriptPath,\n    #     ],\n    #     check=True,\n    # )\n\n    # args = [\n    #     \"python3\",\n    #     scriptPath,\n    #     \"--extract_ema\",\n    #     \"--checkpoint_path\",\n    #     fname,\n    #     \"--dump_path\",\n    #     model_id,\n    # ]\n\n    # if checkpoint_config_url:\n    #     args.append(\"--original_config_file\")\n    #     args.append(configPath)\n\n    # subprocess.run(\n    #     args,\n    #     check=True,\n    # )\n\n    # Oh yay!  Diffusers abstracted this now, so much easier to use.\n    # But less tested.  Changed on 2023-02-18.  
TODO, remove commented\n    # out code above once this has more usage.\n\n    # diffusers defaults\n    args = {\n        \"scheduler_type\": \"pndm\",\n    }\n\n    # our defaults\n    args.update(\n        {\n            \"checkpoint_path_or_dict\": path,\n            \"original_config_file\": configPath if checkpoint_config_url else None,\n            \"device\": device_id,\n            \"extract_ema\": True,\n            \"from_safetensors\": \"safetensor\" in path.lower(),\n        }\n    )\n\n    if \"inpaint\" in path or \"Inpaint\" in path:\n        args.update({\"pipeline_class\": StableDiffusionInpaintPipeline})\n\n    # user overrides\n    args.update(checkpoint_args)\n\n    pipe = download_from_original_stable_diffusion_ckpt(**args)\n    pipe.save_pretrained(model_id, safe_serialization=True)\n\n\nif __name__ == \"__main__\":\n    # response = requests.get(\n    #    \"https://github.com/huggingface/diffusers/raw/main/scripts/convert_original_stable_diffusion_to_diffusers.py\"\n    # )\n    # open(\"convert_original_stable_diffusion_to_diffusers.py\", \"wb\").write(\n    #    response.content\n    # )\n\n    if CHECKPOINT_URL and CHECKPOINT_URL != \"\":\n        checkpoint_args = json.loads(CHECKPOINT_ARGS) if CHECKPOINT_ARGS else {}\n        main(\n            MODEL_ID,\n            CHECKPOINT_URL,\n            CHECKPOINT_CONFIG_URL,\n            checkpoint_args=checkpoint_args,\n        )\n"
  },
  {
    "path": "api/device.py",
    "content": "import torch\n\nif torch.cuda.is_available():\n    print(\"[device] CUDA (Nvidia) detected\")\n    device_id = \"cuda\"\n    device_name = torch.cuda.get_device_name()\nelif torch.backends.mps.is_available():\n    print(\"[device] MPS (MacOS Metal, Apple M1, etc) detected\")\n    device_id = \"mps\"\n    device_name = \"MPS\"\nelse:\n    print(\"[device] CPU only - no GPU detected\")\n    device_id = \"cpu\"\n    device_name = \"CPU only\"\n\n    if not torch.backends.cuda.is_built():\n        print(\n            \"CUDA not available because the current PyTorch install was not \"\n            \"built with CUDA enabled.\"\n        )\n    if torch.backends.mps.is_built():\n        print(\n            \"MPS not available because the current MacOS version is not 12.3+ \"\n            \"and/or you do not have an MPS-enabled device on this machine.\"\n        )\n    else:\n        print(\n            \"MPS not available because the current PyTorch install was not \"\n            \"built with MPS enabled.\"\n        )\n\ndevice = torch.device(device_id)\n"
  },
  {
    "path": "api/download.py",
    "content": "# In this file, we define download_model\n# It runs during container build time to get model weights built into the container\n\nimport os\nfrom loadModel import loadModel, MODEL_IDS\nfrom diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom utils import Storage\nimport subprocess\nfrom pathlib import Path\nimport shutil\nfrom convert_to_diffusers import main as convert_to_diffusers\nfrom download_checkpoint import main as download_checkpoint\nfrom status import status\nimport asyncio\n\nUSE_DREAMBOOTH = os.environ.get(\"USE_DREAMBOOTH\")\nHF_AUTH_TOKEN = os.environ.get(\"HF_AUTH_TOKEN\")\nRUNTIME_DOWNLOADS = os.environ.get(\"RUNTIME_DOWNLOADS\")\n\nHOME = os.path.expanduser(\"~\")\nMODELS_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\")\nPath(MODELS_DIR).mkdir(parents=True, exist_ok=True)\n\n\n# i.e. don't run during build\nasync def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):\n    if RUNTIME_DOWNLOADS:\n        from send import send as _send\n\n        await _send(type, status, payload, send_opts)\n\n\ndef normalize_model_id(model_id: str, model_revision):\n    normalized_model_id = \"models--\" + model_id.replace(\"/\", \"--\")\n    if model_revision:\n        normalized_model_id += \"--\" + model_revision\n    return normalized_model_id\n\n\nasync def download_model(\n    model_url=None,\n    model_id=None,\n    model_revision=None,\n    checkpoint_url=None,\n    checkpoint_config_url=None,\n    hf_model_id=None,\n    model_precision=None,\n    send_opts={},\n    pipeline_class=None,\n):\n    print(\n        \"download_model\",\n        {\n            \"model_url\": model_url,\n            \"model_id\": model_id,\n            \"model_revision\": model_revision,\n            \"hf_model_id\": hf_model_id,\n            \"checkpoint_url\": checkpoint_url,\n            \"checkpoint_config_url\": checkpoint_config_url,\n        },\n    )\n    
hf_model_id = hf_model_id or model_id\n    normalized_model_id = model_id\n\n    # if model_url != \"\": # throws an error, useful to debug stdout/stderr order\n    if model_url:\n        normalized_model_id = normalize_model_id(model_id, model_revision)\n        print({\"normalized_model_id\": normalized_model_id})\n        filename = model_url.split(\"/\").pop()\n        if not filename:\n            filename = normalized_model_id + \".tar.zst\"\n        model_file = os.path.join(MODELS_DIR, filename)\n        storage = Storage(\n            model_url, default_path=normalized_model_id + \".tar.zst\", status=status\n        )\n        exists = storage.file_exists()\n        if exists:\n            model_dir = os.path.join(MODELS_DIR, normalized_model_id)\n            print(\"model_dir\", model_dir)\n            await asyncio.to_thread(storage.download_and_extract, model_file, model_dir)\n        else:\n            if checkpoint_url:\n                path = download_checkpoint(checkpoint_url)\n                convert_to_diffusers(\n                    model_id=model_id,\n                    checkpoint_url=checkpoint_url,\n                    checkpoint_config_url=checkpoint_config_url,\n                    path=path,\n                )\n            else:\n                print(\"Does not exist, let's try find it on huggingface\")\n                print(\n                    {\n                        \"model_precision\": model_precision,\n                        \"model_revision\": model_revision,\n                    }\n                )\n                # This would be quicker to just model.to(device) afterwards, but\n                # this conveniently logs all the timings (and doesn't happen often)\n                print(\"download\")\n                await send(\"download\", \"start\", {}, send_opts)\n                model = loadModel(\n                    hf_model_id,\n                    False,\n                    precision=model_precision,\n                
    revision=model_revision,\n                    pipeline_class=pipeline_class,\n                )  # download\n                await send(\"download\", \"done\", {}, send_opts)\n\n            print(\"load\")\n            model = loadModel(\n                hf_model_id,\n                True,\n                precision=model_precision,\n                revision=model_revision,\n                pipeline_class=pipeline_class,\n            )  # load\n            # dir = \"models--\" + model_id.replace(\"/\", \"--\") + \"--dda\"\n            dir = os.path.join(MODELS_DIR, normalized_model_id)\n            model.save_pretrained(dir, safe_serialization=True)\n\n            # This is all duped from train_dreambooth, need to refactor TODO XXX\n            await send(\"compress\", \"start\", {}, send_opts)\n            subprocess.run(\n                f\"tar cvf - -C {dir} . | zstd -o {model_file}\",\n                shell=True,\n                check=True,  # TODO, rather don't raise and return an error in JSON\n            )\n\n            await send(\"compress\", \"done\", {}, send_opts)\n            subprocess.run([\"ls\", \"-l\", model_file])\n\n            await send(\"upload\", \"start\", {}, send_opts)\n            upload_result = storage.upload_file(model_file, filename)\n            await send(\"upload\", \"done\", {}, send_opts)\n            print(upload_result)\n            os.remove(model_file)\n\n            # leave model dir for future loads... 
make configurable?\n            # shutil.rmtree(dir)\n\n            # TODO, swap directories, inside HF's cache structure.\n\n    else:\n        if checkpoint_url:\n            path = download_checkpoint(checkpoint_url)\n            convert_to_diffusers(\n                model_id=model_id,\n                checkpoint_url=checkpoint_url,\n                checkpoint_config_url=checkpoint_config_url,\n                path=path,\n            )\n        else:\n            # do a dry run of loading the huggingface model, which will download weights at build time\n            loadModel(\n                model_id=hf_model_id,\n                load=False,\n                precision=model_precision,\n                revision=model_revision,\n                pipeline_class=pipeline_class,\n            )\n\n    # if USE_DREAMBOOTH:\n    # Actually we can re-use these from the above loaded model\n    # Will remove this soon if no more surprises\n    # for subfolder, model in [\n    #     [\"tokenizer\", CLIPTokenizer],\n    #     [\"text_encoder\", CLIPTextModel],\n    #     [\"vae\", AutoencoderKL],\n    #     [\"unet\", UNet2DConditionModel],\n    #     [\"scheduler\", DDPMScheduler]\n    # ]:\n    #     print(subfolder, model)\n    #     model.from_pretrained(\n    #         MODEL_ID,\n    #         subfolder=subfolder,\n    #         revision=revision,\n    #         use_auth_token=HF_AUTH_TOKEN,\n    #     )\n\n\nif __name__ == \"__main__\":\n    asyncio.run(\n        download_model(\n            model_url=os.environ.get(\"MODEL_URL\"),\n            model_id=os.environ.get(\"MODEL_ID\"),\n            hf_model_id=os.environ.get(\"HF_MODEL_ID\"),\n            model_revision=os.environ.get(\"MODEL_REVISION\"),\n            model_precision=os.environ.get(\"MODEL_PRECISION\"),\n            checkpoint_url=os.environ.get(\"CHECKPOINT_URL\"),\n            checkpoint_config_url=os.environ.get(\"CHECKPOINT_CONFIG_URL\"),\n        )\n    )\n"
  },
  {
    "path": "api/download_checkpoint.py",
    "content": "import os\nfrom utils import Storage\n\nCHECKPOINT_URL = os.environ.get(\"CHECKPOINT_URL\", None)\nCHECKPOINT_DIR = \"/root/.cache/checkpoints\"\n\n\ndef main(checkpoint_url: str):\n    if not os.path.isdir(CHECKPOINT_DIR):\n        os.makedirs(CHECKPOINT_DIR)\n\n    storage = Storage(checkpoint_url)\n    storage_query_fname = storage.query.get(\"fname\")\n    if storage_query_fname:\n        fname = storage_query_fname[0]\n    else:\n        fname = checkpoint_url.split(\"/\").pop()\n    path = os.path.join(CHECKPOINT_DIR, fname)\n\n    if not os.path.isfile(path):\n        storage.download_file(path)\n\n    return path\n\n\nif __name__ == \"__main__\":\n    if CHECKPOINT_URL:\n        main(CHECKPOINT_URL)\n"
  },
  {
    "path": "api/extras/__init__.py",
    "content": "from .upsample import upsample\n"
  },
  {
    "path": "api/extras/upsample/__init__.py",
    "content": "from .upsample import upsample\n"
  },
  {
    "path": "api/extras/upsample/models.py",
    "content": "upsamplers = {\n    \"RealESRGAN_x4plus\": {\n        \"name\": \"General - RealESRGANplus\",\n        \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth\",\n        \"filename\": \"RealESRGAN_x4plus.pth\",\n        \"net\": \"RRDBNet\",\n        \"initArgs\": {\n            \"num_in_ch\": 3,\n            \"num_out_ch\": 3,\n            \"num_feat\": 64,\n            \"num_block\": 23,\n            \"num_grow_ch\": 32,\n            \"scale\": 4,\n        },\n        \"netscale\": 4,\n    },\n    # \"RealESRNet_x4plus\": {\n    #     \"name\": \"\",\n    #     \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth\",\n    #     \"path\": \"weights/RealESRNet_x4plus.pth\",\n    # },\n    \"RealESRGAN_x4plus_anime_6B\": {\n        \"name\": \"Anime - anime6B\",\n        \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth\",\n        \"filename\": \"RealESRGAN_x4plus_anime_6B.pth\",\n        \"net\": \"RRDBNet\",\n        \"initArgs\": {\n            \"num_in_ch\": 3,\n            \"num_out_ch\": 3,\n            \"num_feat\": 64,\n            \"num_block\": 6,\n            \"num_grow_ch\": 32,\n            \"scale\": 4,\n        },\n        \"netscale\": 4,\n    },\n    # \"RealESRGAN_x2plus\": {\n    #     \"name\": \"\",\n    #     \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth\",\n    #     \"path\": \"weights/RealESRGAN_x2plus.pth\",\n    # },\n    # \"realesr-animevideov3\": {\n    #     \"name\": \"AnimeVideo - v3\",\n    #     \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth\",\n    #    \"path\": \"weights/realesr-animevideov3.pth\",\n    # },\n    \"realesr-general-x4v3\": {\n        \"name\": \"General - v3\",\n        # [, \"weights\": 
\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth\" ],\n        \"weights\": \"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth\",\n        \"filename\": \"realesr-general-x4v3.pth\",\n        \"net\": \"SRVGGNetCompact\",\n        \"initArgs\": {\n            \"num_in_ch\": 3,\n            \"num_out_ch\": 3,\n            \"num_feat\": 64,\n            \"num_conv\": 32,\n            \"upscale\": 4,\n            \"act_type\": \"prelu\",\n        },\n        \"netscale\": 4,\n    },\n}\n\nface_enhancers = {\n    \"GFPGAN\": {\n        \"name\": \"GFPGAN\",\n        \"weights\": \"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth\",\n        \"filename\": \"GFPGANv1.4.pth\",\n    },\n}\n\nmodels_by_type = {\n    \"upsamplers\": upsamplers,\n    \"face_enhancers\": face_enhancers,\n}\n"
  },
  {
    "path": "api/extras/upsample/upsample.py",
    "content": "import os\nimport asyncio\nfrom pathlib import Path\n\nimport base64\nfrom io import BytesIO\nimport PIL\nimport json\nimport cv2\nimport numpy as np\nimport torch\nimport torchvision\n\nfrom basicsr.archs.rrdbnet_arch import RRDBNet\nfrom realesrgan import RealESRGANer\nfrom realesrgan.archs.srvgg_arch import SRVGGNetCompact\nfrom gfpgan import GFPGANer\n\nfrom .models import models_by_type, upsamplers, face_enhancers\nfrom status import status\nfrom utils import Storage\nfrom send import send\n\nprint(\n    {\n        \"torch.__version__\": torch.__version__,\n        \"torchvision.__version__\": torchvision.__version__,\n    }\n)\n\nHOME = os.path.expanduser(\"~\")\nCACHE_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\", \"upsample\")\n\n\ndef cache_path(filename):\n    return os.path.join(CACHE_DIR, filename)\n\n\nasync def assert_model_exists(src, filename, send_opts, opts={}):\n    dest = cache_path(filename) if not opts.get(\"absolutePath\", None) else filename\n    if not os.path.exists(dest):\n        await send(\"download\", \"start\", {}, send_opts)\n        storage = Storage(src, status=status)\n        # await storage.download_file(dest)\n        await asyncio.to_thread(storage.download_file, dest)\n        await send(\"download\", \"done\", {}, send_opts)\n\n\nasync def download_models(send_opts={}):\n    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)\n\n    for type in models_by_type:\n        models = models_by_type[type]\n        for model_key in models:\n            model = models[model_key]\n            await assert_model_exists(model[\"weights\"], model[\"filename\"], send_opts)\n\n    Path(\"gfpgan/weights\").mkdir(parents=True, exist_ok=True)\n\n    await assert_model_exists(\n        \"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth\",\n        \"detection_Resnet50_Final.pth\",\n        send_opts,\n    )\n    await assert_model_exists(\n        
\"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth\",\n        \"parsing_parsenet.pth\",\n        send_opts,\n    )\n\n    # hardcoded paths in xinntao/facexlib\n    filenames = [\"detection_Resnet50_Final.pth\", \"parsing_parsenet.pth\"]\n    for file in filenames:\n        if not os.path.exists(f\"gfpgan/weights/{file}\"):\n            os.symlink(cache_path(file), f\"gfpgan/weights/{file}\")\n\n\nnets = {\n    \"RRDBNet\": RRDBNet,\n    \"SRVGGNetCompact\": SRVGGNetCompact,\n}\n\nmodels = {}\n\n\nasync def upsample(model_inputs, call_inputs, send_opts={}, startRequestId=None):\n    global models\n\n    # TODO, only download relevant models for this request\n    await download_models()\n\n    model_id = call_inputs.get(\"MODEL_ID\", None)\n\n    if not model_id:\n        return {\n            \"$error\": {\n                \"code\": \"MISSING_MODEL_ID\",\n                \"message\": \"call_inputs.MODEL_ID is required, but not given.\",\n            }\n        }\n\n    model = models.get(model_id, None)\n    if not model:\n        model = models_by_type[\"upsamplers\"].get(model_id, None)\n        if not model:\n            return {\n                \"$error\": {\n                    \"code\": \"MISSING_MODEL\",\n                    \"message\": f'Model \"{model_id}\" not available on this container.',\n                    \"requested\": model_id,\n                    \"available\": '\"' + '\", \"'.join(models.keys()) + '\"',\n                }\n            }\n        else:\n            modelModel = nets[model[\"net\"]](**model[\"initArgs\"])\n            await send(\n                \"loadModel\",\n                \"start\",\n                {\"startRequestId\": startRequestId},\n                send_opts,\n            )\n            upsampler = RealESRGANer(\n                scale=model[\"netscale\"],\n                model_path=cache_path(model[\"filename\"]),\n                dni_weight=None,\n                
model=modelModel,\n                tile=0,\n                tile_pad=10,\n                pre_pad=0,\n                half=True,\n            )\n            await send(\n                \"loadModel\",\n                \"done\",\n                {\"startRequestId\": startRequestId},\n                send_opts,\n            )\n            model.update({\"model\": modelModel, \"upsampler\": upsampler})\n            models.update({model_id: model})\n\n    upsampler = model[\"upsampler\"]\n\n    input_image = model_inputs.get(\"input_image\", None)\n    if not input_image:\n        return {\n            \"$error\": {\n                \"code\": \"NO_INPUT_IMAGE\",\n                \"message\": \"Missing required parameter `input_image`\",\n            }\n        }\n\n    if model_id == \"realesr-general-x4v3\":\n        denoise_strength = model_inputs.get(\"denoise_strength\", 1)\n        if denoise_strength != 1:\n            # wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')\n            # model_path = [model_path, wdn_model_path]\n            # upsampler = models[\"realesr-general-x4v3-denoise\"]\n            # upsampler.dni_weight = dni_weight\n            dni_weight = [denoise_strength, 1 - denoise_strength]\n            return \"TODO: denoise_strength\"\n\n    face_enhance = model_inputs.get(\"face_enhance\", False)\n    if face_enhance:\n        face_enhancer = models.get(\"GFPGAN\", None)\n        if not face_enhancer:\n            await send(\n                \"loadModel\",\n                \"start\",\n                {\"startRequestId\": startRequestId},\n                send_opts,\n            )\n            print(\"1) \" + cache_path(face_enhancers[\"GFPGAN\"][\"filename\"]))\n            face_enhancer = GFPGANer(\n                model_path=cache_path(face_enhancers[\"GFPGAN\"][\"filename\"]),\n                upscale=4,  # args.outscale,\n                arch=\"clean\",\n                channel_multiplier=2,\n          
      bg_upsampler=upsampler,\n            )\n            await send(\n                \"loadModel\",\n                \"done\",\n                {\"startRequestId\": startRequestId},\n                send_opts,\n            )\n            models.update({\"GFPGAN\": face_enhancer})\n\n    if face_enhance:  # Use GFPGAN for face enhancement\n        face_enhancer.bg_upsampler = upsampler\n\n    # image = decodeBase64Image(model_inputs.get(\"input_image\"))\n    image_str = base64.b64decode(model_inputs[\"input_image\"])\n    image_np = np.frombuffer(image_str, dtype=np.uint8)\n    # bytes = BytesIO(base64.decodebytes(bytes(model_inputs[\"input_image\"], \"utf-8\")))\n    img = cv2.imdecode(image_np, cv2.IMREAD_UNCHANGED)\n\n    await send(\"inference\", \"start\", {\"startRequestId\": startRequestId}, send_opts)\n\n    # Run the model\n    # with autocast(\"cuda\"):\n    #    image = pipeline(**model_inputs).images[0]\n    if face_enhance:\n        _, _, output = face_enhancer.enhance(\n            img, has_aligned=False, only_center_face=False, paste_back=True\n        )\n    else:\n        output, _rgb = upsampler.enhance(img, outscale=4)  # TODO outscale param\n\n    image_base64 = base64.b64encode(cv2.imencode(\".jpg\", output)[1]).decode()\n\n    await send(\"inference\", \"done\", {\"startRequestId\": startRequestId}, send_opts)\n\n    # Return the results as a dictionary\n    return {\"$meta\": {}, \"image_base64\": image_base64}\n"
  },
  {
    "path": "api/getPipeline.py",
    "content": "import time\nimport os, fnmatch\nfrom diffusers import (\n    DiffusionPipeline,\n    pipelines as diffusers_pipelines,\n)\nfrom precision import torch_dtype_from_precision\n\nHOME = os.path.expanduser(\"~\")\nMODELS_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\")\n_pipelines = {}\n_availableCommunityPipelines = None\n\n\ndef listAvailablePipelines():\n    return (\n        list(\n            filter(\n                lambda key: key.endswith(\"Pipeline\"),\n                list(diffusers_pipelines.__dict__.keys()),\n            )\n        )\n        + availableCommunityPipelines()\n    )\n\n\ndef availableCommunityPipelines():\n    global _availableCommunityPipelines\n    if not _availableCommunityPipelines:\n        _availableCommunityPipelines = list(\n            map(\n                lambda s: s[0:-3],\n                fnmatch.filter(os.listdir(\"diffusers/examples/community\"), \"*.py\"),\n            )\n        )\n\n    return _availableCommunityPipelines\n\n\ndef clearPipelines():\n    \"\"\"\n    Clears the pipeline cache.  Important to call this when changing the\n    loaded model, as pipelines include references to the model and would\n    therefore prevent memory being reclaimed after unloading the previous\n    model.\n    \"\"\"\n    global _pipelines\n    _pipelines = {}\n\n\ndef getPipelineClass(pipeline_name: str):\n    if hasattr(diffusers_pipelines, pipeline_name):\n        return getattr(diffusers_pipelines, pipeline_name)\n    elif pipeline_name in availableCommunityPipelines():\n        return DiffusionPipeline\n\n\ndef getPipelineForModel(\n    pipeline_name: str, model, model_id, model_revision, model_precision\n):\n    \"\"\"\n    Inits a new pipeline, re-using components from a previously loaded\n    model.  The pipeline is cached and future calls with the same\n    arguments will return the previously initted instance.  
Be sure\n    to call `clearPipelines()` if loading a new model, to allow the\n    previous model to be garbage collected.\n    \"\"\"\n    pipeline = _pipelines.get(pipeline_name)\n    if pipeline:\n        return pipeline\n\n    start = time.time()\n\n    if hasattr(diffusers_pipelines, pipeline_name):\n        pipeline_class = getattr(diffusers_pipelines, pipeline_name)\n        if hasattr(pipeline_class, \"from_pipe\"):\n            pipeline = pipeline_class.from_pipe(model)\n        elif hasattr(model, \"components\"):\n            pipeline = pipeline_class(**model.components)\n        else:\n            pipeline = getattr(diffusers_pipelines, pipeline_name)(\n                vae=model.vae,\n                text_encoder=model.text_encoder,\n                tokenizer=model.tokenizer,\n                unet=model.unet,\n                scheduler=model.scheduler,\n                safety_checker=model.safety_checker,\n                feature_extractor=model.feature_extractor,\n            )\n\n    elif pipeline_name in availableCommunityPipelines():\n        model_dir = os.path.join(MODELS_DIR, model_id)\n        if not os.path.isdir(model_dir):\n            model_dir = None\n\n        pipeline = DiffusionPipeline.from_pretrained(\n            model_dir or model_id,\n            revision=model_revision,\n            torch_dtype=torch_dtype_from_precision(model_precision),\n            custom_pipeline=\"./diffusers/examples/community/\" + pipeline_name + \".py\",\n            local_files_only=True,\n            **model.components,\n        )\n\n    if pipeline:\n        _pipelines.update({pipeline_name: pipeline})\n        diff = round((time.time() - start) * 1000)\n        print(f\"Initialized {pipeline_name} for {model_id} in {diff}ms\")\n        return pipeline\n"
  },
  {
    "path": "api/getScheduler.py",
    "content": "import torch\nimport os\nimport time\nfrom diffusers import schedulers as _schedulers\n\nHF_AUTH_TOKEN = os.getenv(\"HF_AUTH_TOKEN\")\nHOME = os.path.expanduser(\"~\")\nMODELS_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\")\n\nSCHEDULERS = [\n    \"DPMSolverMultistepScheduler\",\n    \"LMSDiscreteScheduler\",\n    \"DDIMScheduler\",\n    \"PNDMScheduler\",\n    \"EulerAncestralDiscreteScheduler\",\n    \"EulerDiscreteScheduler\",\n]\n\nDEFAULT_SCHEDULER = os.getenv(\"DEFAULT_SCHEDULER\", SCHEDULERS[0])\n\n\n\"\"\"\n# This was a nice idea but until we have default init vars for all schedulers\n# via from_pretrained(), it's a no go.  In any case, loading a scheduler takes time\n# so better to init as needed and cache.\nisScheduler = re.compile(r\".+Scheduler$\")\nfor key, val in _schedulers.__dict__.items():\n    if isScheduler.match(key):\n        schedulers.update(\n            {\n                key: val.from_pretrained(\n                    MODEL_ID, subfolder=\"scheduler\", use_auth_token=HF_AUTH_TOKEN\n                )\n            }\n        )\n\"\"\"\n\n\ndef initScheduler(MODEL_ID: str, scheduler_id: str, download=False):\n    print(f\"Initializing {scheduler_id} for {MODEL_ID}...\")\n    start = time.time()\n    scheduler = getattr(_schedulers, scheduler_id)\n    if scheduler == None:\n        return None\n\n    model_dir = os.path.join(MODELS_DIR, MODEL_ID)\n    if not os.path.isdir(model_dir):\n        model_dir = None\n\n    inittedScheduler = scheduler.from_pretrained(\n        model_dir or MODEL_ID,\n        subfolder=\"scheduler\",\n        use_auth_token=HF_AUTH_TOKEN,\n        local_files_only=not download,\n    )\n    diff = round((time.time() - start) * 1000)\n    print(f\"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms\")\n\n    return inittedScheduler\n\n\nschedulers = {}\n\n\ndef getScheduler(MODEL_ID: str, scheduler_id: str, download=False):\n    schedulersByModel = schedulers.get(MODEL_ID, None)\n    if 
schedulersByModel == None:\n        schedulersByModel = {}\n        schedulers.update({MODEL_ID: schedulersByModel})\n\n    # Check for use of old names\n    deprecated_map = {\n        \"LMS\": \"LMSDiscreteScheduler\",\n        \"DDIM\": \"DDIMScheduler\",\n        \"PNDM\": \"PNDMScheduler\",\n    }\n    scheduler_renamed = deprecated_map.get(scheduler_id, None)\n    if scheduler_renamed != None:\n        print(\n            f'[Deprecation Warning]: Scheduler \"{scheduler_id}\" is now '\n            f'called \"{scheduler_renamed}\".  Please rename as this will '\n            f\"stop working in a future release.\"\n        )\n        scheduler_id = scheduler_renamed\n\n    scheduler = schedulersByModel.get(scheduler_id, None)\n    if scheduler == None:\n        scheduler = initScheduler(MODEL_ID, scheduler_id, download)\n        schedulersByModel.update({scheduler_id: scheduler})\n\n    return scheduler\n"
  },
  {
    "path": "api/lib/__init__.py",
    "content": ""
  },
  {
    "path": "api/lib/prompts.py",
    "content": "from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType\n\n\ndef prepare_prompts(pipeline, model_inputs, is_sdxl):\n    textual_inversion_manager = DiffusersTextualInversionManager(pipeline)\n    if is_sdxl:\n        compel = Compel(\n            tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],\n            text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],\n            # diffusers has no ti in sdxl yet\n            # https://github.com/huggingface/diffusers/issues/4376#issuecomment-1659016141\n            # textual_inversion_manager=textual_inversion_manager,\n            truncate_long_prompts=False,\n            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,\n            requires_pooled=[False, True],\n        )\n        conditioning, pooled = compel(model_inputs.get(\"prompt\"))\n        negative_conditioning, negative_pooled = compel(\n            model_inputs.get(\"negative_prompt\")\n        )\n        [\n            conditioning,\n            negative_conditioning,\n        ] = compel.pad_conditioning_tensors_to_same_length(\n            [conditioning, negative_conditioning]\n        )\n        model_inputs.update(\n            {\n                \"prompt\": None,\n                \"negative_prompt\": None,\n                \"prompt_embeds\": conditioning,\n                \"negative_prompt_embeds\": negative_conditioning,\n                \"pooled_prompt_embeds\": pooled,\n                \"negative_pooled_prompt_embeds\": negative_pooled,\n            }\n        )\n\n    else:\n        compel = Compel(\n            tokenizer=pipeline.tokenizer,\n            text_encoder=pipeline.text_encoder,\n            textual_inversion_manager=textual_inversion_manager,\n            truncate_long_prompts=False,\n        )\n        conditioning = compel(model_inputs.get(\"prompt\"))\n        negative_conditioning = compel(model_inputs.get(\"negative_prompt\"))\n       
 [\n            conditioning,\n            negative_conditioning,\n        ] = compel.pad_conditioning_tensors_to_same_length(\n            [conditioning, negative_conditioning]\n        )\n        model_inputs.update(\n            {\n                \"prompt\": None,\n                \"negative_prompt\": None,\n                \"prompt_embeds\": conditioning,\n                \"negative_prompt_embeds\": negative_conditioning,\n            }\n        )\n"
  },
  {
    "path": "api/lib/textual_inversions.py",
    "content": "import json\nimport re\nimport os\nimport asyncio\nfrom utils import Storage\nfrom .vars import MODELS_DIR\n\nlast_textual_inversions = None\nlast_textual_inversion_model = None\nloaded_textual_inversion_tokens = []\n\ntokenRe = re.compile(\n    r\"[#&]{1}fname=(?P<fname>[^\\.]+)\\.(?:pt|safetensors)(&token=(?P<token>[^&]+))?$\"\n)\n\n\ndef strMap(str: str):\n    match = re.search(tokenRe, str)\n    # print(match)\n    if match:\n        return match.group(\"token\") or match.group(\"fname\")\n\n\ndef extract_tokens_from_list(textual_inversions: list):\n    return list(map(strMap, textual_inversions))\n\n\nasync def handle_textual_inversions(textual_inversions: list, model, status):\n    global last_textual_inversions\n    global last_textual_inversion_model\n    global loaded_textual_inversion_tokens\n\n    textual_inversions_str = json.dumps(textual_inversions)\n    if (\n        textual_inversions_str != last_textual_inversions\n        or model is not last_textual_inversion_model\n    ):\n        if model is not last_textual_inversion_model:\n            loaded_textual_inversion_tokens = []\n            last_textual_inversion_model = model\n        # print({\"textual_inversions\": textual_inversions})\n        # tokens_to_load = extract_tokens_from_list(textual_inversions)\n        # print({\"tokens_loaded\": loaded_textual_inversion_tokens})\n        # print({\"tokens_to_load\": tokens_to_load})\n        #\n        # for token in loaded_textual_inversion_tokens:\n        #     if token not in tokens_to_load:\n        #         print(\"[TextualInversion] Removing uneeded token: \" + token)\n        #         del pipeline.tokenizer.get_vocab()[token]\n        #         # del pipeline.text_encoder.get_input_embeddings().weight.data[token]\n        #         pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))\n        #\n        # loaded_textual_inversion_tokens = tokens_to_load\n\n        last_textual_inversions = 
textual_inversions_str\n        for textual_inversion in textual_inversions:\n            storage = Storage(textual_inversion, no_raise=True, status=status)\n            if storage:\n                storage_query_fname = storage.query.get(\"fname\")\n                if storage_query_fname:\n                    fname = storage_query_fname[0]\n                else:\n                    fname = textual_inversion.split(\"/\").pop()\n                path = os.path.join(MODELS_DIR, \"textual_inversion--\" + fname)\n                if not os.path.exists(path):\n                    await asyncio.to_thread(storage.download_file, path)\n                print(\"Load textual inversion \" + path)\n                token = storage.query.get(\"token\", None)\n                if token not in loaded_textual_inversion_tokens:\n                    model.load_textual_inversion(\n                        path, token=token, local_files_only=True\n                    )\n                    loaded_textual_inversion_tokens.append(token)\n            else:\n                print(\"Load textual inversion \" + textual_inversion)\n                model.load_textual_inversion(textual_inversion)\n    else:\n        print(\"No changes to textual inversions since last call\")\n"
  },
  {
    "path": "api/lib/textual_inversions_test.py",
    "content": "import unittest\nfrom .textual_inversions import extract_tokens_from_list\n\n\nclass TextualInversionsTest(unittest.TestCase):\n    def test_extract_tokens_query_fname(self):\n        tis = [\"https://civitai.com/api/download/models/106132#fname=4nj0lie.pt\"]\n        tokens = extract_tokens_from_list(tis)\n        self.assertEqual(tokens[0], \"4nj0lie\")\n\n    def test_extract_tokens_query_token(self):\n        tis = [\n            \"https://civitai.com/api/download/models/106132#fname=4nj0lie.pt&token=4nj0lie\"\n        ]\n        tokens = extract_tokens_from_list(tis)\n        self.assertEqual(tokens[0], \"4nj0lie\")\n"
  },
  {
    "path": "api/lib/vars.py",
    "content": "import os\n\nRUNTIME_DOWNLOADS = os.getenv(\"RUNTIME_DOWNLOADS\") == \"1\"\nUSE_DREAMBOOTH = os.getenv(\"USE_DREAMBOOTH\") == \"1\"\nMODEL_ID = os.environ.get(\"MODEL_ID\")\nPIPELINE = os.environ.get(\"PIPELINE\")\nHF_AUTH_TOKEN = os.getenv(\"HF_AUTH_TOKEN\")\nHOME = os.path.expanduser(\"~\")\nMODELS_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\")\n"
  },
  {
    "path": "api/loadModel.py",
    "content": "import torch\nimport os\nfrom diffusers import pipelines as _pipelines, AutoPipelineForText2Image\nfrom getScheduler import getScheduler, DEFAULT_SCHEDULER\nfrom precision import torch_dtype_from_precision\nfrom device import device\nimport time\n\nHF_AUTH_TOKEN = os.getenv(\"HF_AUTH_TOKEN\")\nPIPELINE = os.getenv(\"PIPELINE\")\nUSE_DREAMBOOTH = True if os.getenv(\"USE_DREAMBOOTH\") == \"1\" else False\nHOME = os.path.expanduser(\"~\")\nMODELS_DIR = os.path.join(HOME, \".cache\", \"diffusers-api\")\n\n\nMODEL_IDS = [\n    \"CompVis/stable-diffusion-v1-4\",\n    \"hakurei/waifu-diffusion\",\n    # \"hakurei/waifu-diffusion-v1-3\", - not as diffusers yet\n    \"runwayml/stable-diffusion-inpainting\",\n    \"runwayml/stable-diffusion-v1-5\",\n    \"stabilityai/stable-diffusion-2\"\n    \"stabilityai/stable-diffusion-2-base\"\n    \"stabilityai/stable-diffusion-2-inpainting\",\n]\n\n\ndef loadModel(\n    model_id: str,\n    load=True,\n    precision=None,\n    revision=None,\n    send_opts={},\n    pipeline_class=None,\n):\n    torch_dtype = torch_dtype_from_precision(precision)\n    if revision == \"\":\n        revision = None\n\n    print(\n        \"loadModel\",\n        {\n            \"model_id\": model_id,\n            \"load\": load,\n            \"precision\": precision,\n            \"revision\": revision,\n            \"pipeline_class\": pipeline_class,\n        },\n    )\n\n    if not pipeline_class:\n        pipeline_class = AutoPipelineForText2Image\n\n    pipeline = pipeline_class if PIPELINE == \"ALL\" else getattr(_pipelines, PIPELINE)\n    print(\"pipeline\", pipeline_class)\n\n    print(\n        (\"Loading\" if load else \"Downloading\")\n        + \" model: \"\n        + model_id\n        + (f\" ({revision})\" if revision else \"\")\n    )\n\n    scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)\n\n    model_dir = os.path.join(MODELS_DIR, model_id)\n    if not os.path.isdir(model_dir):\n        model_dir = None\n\n    
from_pretrained = time.time()\n    model = pipeline.from_pretrained(\n        model_dir or model_id,\n        revision=revision,\n        torch_dtype=torch_dtype,\n        use_auth_token=HF_AUTH_TOKEN,\n        scheduler=scheduler,\n        local_files_only=load,\n        # Work around https://github.com/huggingface/diffusers/issues/1246\n        # low_cpu_mem_usage=False if USE_DREAMBOOTH else True,\n    )\n    from_pretrained = round((time.time() - from_pretrained) * 1000)\n\n    if load:\n        to_gpu = time.time()\n        model.to(device)\n        to_gpu = round((time.time() - to_gpu) * 1000)\n        print(f\"Loaded from disk in {from_pretrained} ms, to gpu in {to_gpu} ms\")\n    else:\n        print(f\"Downloaded in {from_pretrained} ms\")\n\n    return model if load else None\n"
  },
  {
    "path": "api/precision.py",
    "content": "import os\nimport torch\n\nDEPRECATED_PRECISION = os.getenv(\"PRECISION\")\nMODEL_PRECISION = os.getenv(\"MODEL_PRECISION\") or DEPRECATED_PRECISION\nMODEL_REVISION = os.getenv(\"MODEL_REVISION\")\n\nif DEPRECATED_PRECISION:\n    print(\"Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION\")\n    print(\"Your setup still works but in a future release, this will throw an error\")\n\nif MODEL_PRECISION and not MODEL_REVISION:\n    print(\"Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please\")\n    print(f'explicitly set MODEL_REVISION=\"{MODEL_PRECISION}\" if that\\'s what you')\n    print(\"want.\")\n\n\ndef revision_from_precision(precision=MODEL_PRECISION):\n    # return precision if precision else None\n    raise Exception(\"revision_from_precision no longer supported\")\n\n\ndef torch_dtype_from_precision(precision=MODEL_PRECISION):\n    if precision == \"fp16\":\n        return torch.float16\n    return None\n\n\ndef torch_dtype_from_precision(precision=MODEL_PRECISION):\n    if precision == \"fp16\":\n        return torch.float16\n    return None\n"
  },
  {
    "path": "api/send.py",
    "content": "import json\nimport os\nimport datetime\nimport time\nimport requests\nimport hashlib\nfrom requests_futures.sessions import FuturesSession\nfrom status import status as statusInstance\n\nprint()\nenviron = os.environ.copy()\nfor key in [\"AWS_ACCESS_KEY_ID\", \"AWS_SECRET_ACCESS_KEY\", \"HF_AUTH_TOKEN\"]:\n    if environ.get(key, None):\n        environ[key] = \"XXX\"\nprint(environ)\nprint()\n\n\ndef get_now():\n    return round(time.time() * 1000)\n\n\nSEND_URL = os.getenv(\"SEND_URL\")\nif SEND_URL == \"\":\n    SEND_URL = None\n\nSIGN_KEY = os.getenv(\"SIGN_KEY\", \"\")\nif SIGN_KEY == \"\":\n    SIGN_KEY = None\n\nfutureSession = FuturesSession()\n\ncontainer_id = os.getenv(\"CONTAINER_ID\")\nif not container_id:\n    with open(\"/proc/self/mountinfo\") as file:\n        line = file.readline().strip()\n        while line:\n            if \"/containers/\" in line:\n                container_id = line.split(\"/containers/\")[\n                    -1\n                ]  # Take only text to the right\n                container_id = container_id.split(\"/\")[0]  # Take only text to the left\n                break\n            line = file.readline().strip()\n\n\ninit_used = False\n\n\ndef clearSession(force=False):\n    global session\n    global init_used\n\n    if init_used or force:\n        session = {\"_ctime\": get_now()}\n    else:\n        init_used = True\n\n\ndef getTimings():\n    timings = {}\n    for key in session.keys():\n        if key == \"_ctime\":\n            continue\n        start = session[key].get(\"start\", None)\n        done = session[key].get(\"done\", None)\n        if start and done:\n            timings.update({key: session[key][\"done\"] - session[key][\"start\"]})\n        else:\n            timings.update({key: -1})\n    return timings\n\n\nasync def send(type: str, status: str, payload: dict = {}, opts: dict = {}):\n    now = get_now()\n    send_url = opts.get(\"SEND_URL\", SEND_URL)\n    sign_key = 
opts.get(\"SIGN_KEY\", SIGN_KEY)\n\n    if status == \"start\":\n        session.update({type: {\"start\": now, \"last_time\": now}})\n    elif status == \"done\":\n        session[type].update({\"done\": now, \"diff\": now - session[type][\"start\"]})\n    else:\n        session[type][\"last_time\"] = now\n\n    data = {\n        \"type\": type,\n        \"status\": status,\n        \"container_id\": container_id,\n        \"time\": now,\n        \"t\": now - session[\"_ctime\"],\n        \"tsl\": now - session[type][\"last_time\"],\n        \"payload\": payload,\n    }\n\n    if status == \"start\":\n        statusInstance.update(type, 0.0)\n    elif status == \"done\":\n        statusInstance.update(type, 1.0)\n\n    if send_url and sign_key:\n        input = json.dumps(data, separators=(\",\", \":\")) + sign_key\n        sig = hashlib.md5(input.encode(\"utf-8\")).hexdigest()\n        data[\"sig\"] = sig\n\n    print(datetime.datetime.now(), data)\n\n    if send_url:\n        futureSession.post(send_url, json=data)\n\n    response = opts.get(\"response\")\n    if response:\n        print(\"streaming above\")\n        await response.send(json.dumps(data) + \"\\n\")\n\n    # try:\n    #    requests.post(send_url, json=data)  # , timeout=0.0000000001)\n    # except requests.exceptions.ReadTimeout:\n    # except requests.exceptions.RequestException as error:\n    #    print(error)\n    #    pass\n\n\nclearSession(True)\n"
  },
  {
    "path": "api/server.py",
    "content": "# Do not edit if deploying to Banana Serverless\n# This file is boilerplate for the http server, and follows a strict interface.\n\n# Instead, edit the init() and inference() functions in app.py\n\nfrom sanic import Sanic, response\nfrom sanic_ext import Extend\nimport subprocess\nimport app as user_src\nimport traceback\nimport os\nimport json\n\n# We do the model load-to-GPU step on server startup\n# so the model object is available globally for reuse\nuser_src.init()\n\n# Create the http server app\nserver = Sanic(\"my_app\")\nserver.config.CORS_ORIGINS = os.getenv(\"CORS_ORIGINS\") or \"*\"\nserver.config.RESPONSE_TIMEOUT = 60 * 60  # 1 hour (training can be long)\nExtend(server)\n\n\n# Healthchecks verify that the environment is correct on Banana Serverless\n@server.route(\"/healthcheck\", methods=[\"GET\"])\ndef healthcheck(request):\n    # dependency free way to check if GPU is visible\n    gpu = False\n    out = subprocess.run(\"nvidia-smi\", shell=True)\n    if out.returncode == 0:  # success state on shell command\n        gpu = True\n\n    return response.json({\"state\": \"healthy\", \"gpu\": gpu})\n\n\n# Inference POST handler at '/' is called for every http call from Banana\n@server.route(\"/\", methods=[\"POST\"])\nasync def inference(request):\n    try:\n        all_inputs = response.json.loads(request.json)\n    except:\n        all_inputs = request.json\n\n    call_inputs = all_inputs.get(\"callInputs\", None)\n    stream_events = call_inputs and call_inputs.get(\"streamEvents\", 0) != 0\n\n    streaming_response = None\n    if stream_events:\n        streaming_response = await request.respond(content_type=\"application/x-ndjson\")\n\n    try:\n        output = await user_src.inference(all_inputs, streaming_response)\n    except Exception as err:\n        print(err)\n        output = {\n            \"$error\": {\n                \"code\": \"APP_INFERENCE_ERROR\",\n                \"name\": type(err).__name__,\n                
\"message\": str(err),\n                \"stack\": traceback.format_exc(),\n            }\n        }\n\n    if stream_events:\n        await streaming_response.send(json.dumps(output) + \"\\n\")\n    else:\n        return response.json(output)\n\n\nif __name__ == \"__main__\":\n    server.run(host=\"0.0.0.0\", port=\"8000\", workers=1)\n"
  },
  {
    "path": "api/status.py",
    "content": "class Status:\n    def __init__(self):\n        self.type = \"init\"\n        self.progress = 0.0\n\n    def update(self, type, progress):\n        self.type = type\n        self.progress = progress\n\n    def get(self):\n        return {\"type\": self.type, \"progress\": self.progress}\n\n\nstatus = Status()\n"
  },
  {
    "path": "api/tests.py",
    "content": "from test import runTest\n\n\ndef test_memory_free_on_swap_model():\n    \"\"\"\n    Make sure memory is freed when swapping models at runtime.\n    \"\"\"\n    result = runTest(\n        \"txt2img\",\n        {},\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            \"MODEL_PRECISION\": \"\",  # full precision\n            \"MODEL_URL\": \"s3://\",\n        },\n        {\"num_inference_steps\": 1},\n    )\n    mem_usage = list()\n    mem_usage.append(result[\"$mem_usage\"])\n    result = runTest(\n        \"txt2img\",\n        {},\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            \"MODEL_PRECISION\": \"fp16\",  # half precision\n            \"MODEL_URL\": \"s3://\",\n        },\n        {\"num_inference_steps\": 1},\n    )\n    mem_usage.append(result[\"$mem_usage\"])\n\n    print({\"mem_usage\": mem_usage})\n    # Assert that less memory used when unloading fp32 model and\n    # loading the fp16 variant in its place\n    assert mem_usage[1] < mem_usage[0]\n"
  },
  {
    "path": "api/train_dreambooth.py",
    "content": "# Based on https://github.com/huggingface/diffusers/commits/main/examples/dreambooth/train_dreambooth.py\n# Synced to commit b9feed87958c27074b0618cc543696c05f58e2c9 on 2023-07-12\n\n# Reasons for not using that file directly:\n#\n#   1) Use our already loded model from `init()`\n#   2) Callback to run after every iteration\n\n# Deps\n\n#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\nimport argparse\nimport gc\nimport hashlib\nimport itertools\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\n\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, model_info, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom 
diffusers.utils.import_utils import is_xformers_available\n\n# DDA\nfrom send import send as _send\nfrom utils import Storage\nimport subprocess\nimport re\nimport shutil\nimport asyncio\n\n# Our original code in docker-diffusers-api:\n\nHF_AUTH_TOKEN = os.getenv(\"HF_AUTH_TOKEN\")\n\n\ndef send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):\n    asyncio.run((_send(type, status, payload, send_opts)))\n\n\ndef TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opts):\n    # required inputs: instance_images instance_prompt\n\n    params = {\n        # Defaults\n        \"pretrained_model_name_or_path\": model_id,  # DDA, TODO\n        # Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\n        # float32 precision.\n        \"revision\": None,\n        \"tokenizer_name\": None,\n        \"instance_data_dir\": \"instance_data_dir\",  # DDA TODO\n        \"class_data_dir\": \"class_data_dir\",  # DDA, was: None,\n        # instance_prompt\n        \"class_prompt\": None,\n        \"with_prior_preservation\": False,\n        \"prior_loss_weight\": 1.0,\n        \"num_class_images\": 100,\n        \"output_dir\": \"text-inversion-model\",\n        \"seed\": None,\n        \"resolution\": 512,\n        # Whether to center crop the input images to the resolution. If not set, the images will be randomly\n        # cropped. The images will be resized to the resolution first before cropping.\n        \"center_crop\": False,\n        # Whether to train the text encoder. If set, the text encoder should be float32 precision.\n        \"train_text_encoder\": None,\n        \"train_batch_size\": 1,  # DDA, was: 4\n        \"sample_batch_size\": 1,  # DDA, was: 4,\n        \"num_train_epochs\": 1,\n        \"max_train_steps\": 800,  # DDA, was: None,\n        # Save a checkpoint of the training state every X updates. 
Checkpoints can be used for resuming training via `--resume_from_checkpoint`.\n        # In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\n        # Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\n        # See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\n        # instructions.\n        \"checkpointing_steps\": 1000000000,  # DDA, was: 500\n        # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\n        # See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\n        # for more details\n        \"checkpoints_total_limit\": None,\n        \"resume_from_checkpoint\": None,\n        \"gradient_accumulation_steps\": 1,\n        \"gradient_checkpointing\": True,  # DDA was: None (needed for 16GB)\n        \"learning_rate\": 5e-6,\n        \"scale_lr\": False,\n        \"lr_scheduler\": \"constant\",\n        \"lr_warmup_steps\": 0,  # DDA, was: 500,\n        \"lr_num_cycles\": 1,\n        # Power factor of the polynomial scheduler\n        \"lr_power\": 1.0,\n        \"use_8bit_adam\": True,  # DDA, was: None (needed for 16GB)\n        # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\n        \"dataloader_num_workers\": 0,\n        \"adam_beta1\": 0.9,\n        \"adam_beta2\": 0.999,\n        \"adam_weight_decay\": 1e-6,\n        \"adam_epsilon\": 1e-08,\n        \"max_grad_norm\": 1.0,\n        \"push_to_hub\": None,\n        \"hub_token\": HF_AUTH_TOKEN,\n        \"hub_model_id\": None,\n        \"logging_dir\": \"logs\",\n        # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\n        # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n        \"allow_tf32\": None,\n        # The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`\n        # (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.\n        \"report_to\": \"tensorboard\",\n        # A prompt that is used during validation to verify that the model is learning.\n        \"validation_prompt\": None,\n        # Number of images that should be generated during validation with `validation_prompt`\n        \"num_validation_images\": 4,\n        # Run validation every X steps. Validation consists of running the prompt\n        # `args.validation_prompt` multiple times: `args.num_validation_images`\n        # and logging the images.\n        \"validation_steps\": 100,\n        \"mixed_precision\": \"fp16\",  # DDA, was: None\n        # Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\n        # 1.10.and an Nvidia Ampere GPU.  Default to  fp16 if a GPU is available else fp32.\n        \"prior_generation_precision\": None,  # \"no\", \"fp32\", \"fp16\", \"bf16\"\n        \"local_rank\": -1,\n        \"enable_xformers_memory_efficient_attention\": None,\n        # Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain\n        # behaviors, so disable this argument if it causes any problems. More info:\n        # https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\n        \"set_grads_to_none\": None,\n        # Fine-tuning against a modified noise\"\n        # See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information.\n        \"offset_noise\": False,\n        # Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.\n        \"pre_compute_text_embeddings\": False,\n        # The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.\"\n        \"tokenizer_max_length\": None,\n        # Whether to use attention mask for the text encoder\n        \"text_encoder_use_attention_mask\": False,\n        # Set to not save text encoder\n        \"skip_save_text_encoder\": False,\n        # Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.\n        \"validation_images\": None,\n        # The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.\n        \"class_labels_conditioning\": None,\n    }\n\n    instance_images = model_inputs[\"instance_images\"]\n    del model_inputs[\"instance_images\"]\n\n    params.update(model_inputs)\n    print(model_inputs)\n\n    args = argparse.Namespace(**params)\n    print(args)\n\n    if args.train_text_encoder and args.pre_compute_text_embeddings:\n        raise ValueError(\n            \"`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`\"\n        )\n\n    result = {}\n\n    if not args.push_to_hub and call_inputs.get(\"dest_url\", None) == None:\n        print()\n        print(\"WARNING: Neither modelInputs.push_to_hub nor callInputs.dest_url\")\n        print(\"was given.  After training, your model won't be uploaded anywhere.\")\n        print()\n        result.update({\"no_upload\": True})\n\n    # TODO, not save at all... 
we're just getting it working\n    # if its a hassle, in interim, at least save to unique dir\n    if not os.path.exists(args.instance_data_dir):\n        os.mkdir(args.instance_data_dir)\n    for i, image in enumerate(instance_images):\n        image.save(args.instance_data_dir + \"/image\" + str(i) + \".png\")\n\n    subprocess.run([\"ls\", \"-l\", args.instance_data_dir])\n\n    result = result | main(args, pipeline, send_opts=send_opts)\n\n    dest_url = call_inputs.get(\"dest_url\")\n    if dest_url:\n        storage = Storage(dest_url)\n        filename = storage.path if storage.path != \"\" else args.output_dir\n        filename = filename.split(\"/\").pop()\n        print(filename)\n        if not re.search(r\"\\.\", filename):\n            filename += \".tar.zstd\"\n        print(filename)\n\n        # fp16 model timings: zip 1m20s, tar+zstd 4s and a tiny bit smaller!\n        send(\"compress\", \"start\", {}, send_opts)\n\n        # TODO, steaming upload (turns out docker disk write is super slow)\n        subprocess.run(\n            f\"tar cvf - -C {args.output_dir} . | zstd -o {filename}\",\n            shell=True,\n            check=True,  # TODO, rather don't raise and return an error in JSON\n        )\n\n        send(\"compress\", \"done\", {}, send_opts)\n        subprocess.run([\"ls\", \"-l\", filename])\n\n        send(\"upload\", \"start\", {}, send_opts)\n        upload_result = storage.upload_file(filename, filename)\n        send(\"upload\", \"done\", {}, send_opts)\n        print(upload_result)\n        os.remove(filename)\n\n    # Cleanup\n    shutil.rmtree(args.output_dir)\n    shutil.rmtree(args.class_data_dir, ignore_errors=True)\n\n    return result\n\n\n# What follows is mostly the original train_dreambooth.py\n# Any changes are marked with in comments with [DDA].\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.19.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model=str,\n    train_text_encoder=False,\n    prompt=str,\n    repo_folder=None,\n    pipeline: DiffusionPipeline = None,\n):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ninstance_prompt: {prompt}\ntags:\n- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}\n- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}- text-to-image\n- diffusers\n- dreambooth\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# DreamBooth - {repo_id}\nThis is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).\nYou can find some example images in the following. \\n\n{img_str}\nDreamBooth for the text encoder was enabled: {train_text_encoder}.\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(\n    text_encoder,\n    tokenizer,\n    unet,\n    vae,\n    args,\n    accelerator,\n    weight_dtype,\n    epoch,\n    prompt_embeds,\n    negative_prompt_embeds,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    pipeline_args = {}\n\n    if vae is not None:\n        pipeline_args[\"vae\"] = vae\n\n    if text_encoder is not None:\n        text_encoder = accelerator.unwrap_model(text_encoder)\n\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        unet=accelerator.unwrap_model(unet),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n        **pipeline_args,\n    )\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n        pipeline.scheduler.config, **scheduler_args\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.pre_compute_text_embeddings:\n        pipeline_args = {\n            \"prompt_embeds\": prompt_embeds,\n            \"negative_prompt_embeds\": negative_prompt_embeds,\n        }\n    else:\n        pipeline_args = {\"prompt\": args.validation_prompt}\n\n    # run inference\n    generator = (\n        None\n        if args.seed is None\n        else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    )\n    images = []\n    if args.validation_images is None:\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n 
                   **pipeline_args, num_inference_steps=25, generator=generator\n                ).images[0]\n            images.append(image)\n    else:\n        for image in args.validation_images:\n            image = Image.open(image)\n            image = pipeline(**pipeline_args, image=image, generator=generator).images[\n                0\n            ]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\n                \"validation\", np_images, epoch, dataformats=\"NHWC\"\n            )\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (\n            RobertaSeriesModelWithTransformation,\n        )\n\n        return RobertaSeriesModelWithTransformation\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not 
supported.\")\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and the tokenizes prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n        encoder_hidden_states=None,\n        instance_prompt_encoder_hidden_states=None,\n        tokenizer_max_length=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.encoder_hidden_states = encoder_hidden_states\n        self.instance_prompt_encoder_hidden_states = (\n            instance_prompt_encoder_hidden_states\n        )\n        self.tokenizer_max_length = tokenizer_max_length\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\n                f\"Instance {self.instance_data_root} images root doesn't exists.\"\n            )\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            
self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(\n                    size, interpolation=transforms.InterpolationMode.BILINEAR\n                ),\n                transforms.CenterCrop(size)\n                if center_crop\n                else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(\n            self.instance_images_path[index % self.num_instance_images]\n        )\n        instance_image = exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        if self.encoder_hidden_states is not None:\n            example[\"instance_prompt_ids\"] = self.encoder_hidden_states\n        else:\n            text_inputs = tokenize_prompt(\n                self.tokenizer,\n                self.instance_prompt,\n                tokenizer_max_length=self.tokenizer_max_length,\n            )\n            example[\"instance_prompt_ids\"] = text_inputs.input_ids\n            example[\"instance_attention_mask\"] = text_inputs.attention_mask\n\n        if self.class_data_root:\n            class_image = Image.open(\n                self.class_images_path[index % self.num_class_images]\n            )\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n\n            if self.instance_prompt_encoder_hidden_states is not None:\n         
       example[\"class_prompt_ids\"] = self.instance_prompt_encoder_hidden_states\n            else:\n                class_text_inputs = tokenize_prompt(\n                    self.tokenizer,\n                    self.class_prompt,\n                    tokenizer_max_length=self.tokenizer_max_length,\n                )\n                example[\"class_prompt_ids\"] = class_text_inputs.input_ids\n                example[\"class_attention_mask\"] = class_text_inputs.attention_mask\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    has_attention_mask = \"instance_attention_mask\" in examples[0]\n\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    if has_attention_mask:\n        attention_mask = [example[\"instance_attention_mask\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += [example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n\n        if has_attention_mask:\n            attention_mask += [example[\"class_attention_mask\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n\n    if has_attention_mask:\n        attention_mask = torch.cat(attention_mask, dim=0)\n        batch[\"attention_mask\"] = attention_mask\n\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        
self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef model_has_vae(args):\n    config_file_name = os.path.join(\"vae\", AutoencoderKL.config_name)\n    if os.path.isdir(args.pretrained_model_name_or_path):\n        config_file_name = os.path.join(\n            args.pretrained_model_name_or_path, config_file_name\n        )\n        return os.path.isfile(config_file_name)\n    else:\n        files_in_repo = model_info(\n            args.pretrained_model_name_or_path, revision=args.revision\n        ).siblings\n        return any(file.rfilename == config_file_name for file in files_in_repo)\n\n\ndef tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):\n    if tokenizer_max_length is not None:\n        max_length = tokenizer_max_length\n    else:\n        max_length = tokenizer.model_max_length\n\n    text_inputs = tokenizer(\n        prompt,\n        truncation=True,\n        padding=\"max_length\",\n        max_length=max_length,\n        return_tensors=\"pt\",\n    )\n\n    return text_inputs\n\n\ndef encode_prompt(\n    text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None\n):\n    text_input_ids = input_ids.to(text_encoder.device)\n\n    if text_encoder_use_attention_mask:\n        attention_mask = attention_mask.to(text_encoder.device)\n    else:\n        attention_mask = None\n\n    prompt_embeds = text_encoder(\n        text_input_ids,\n        attention_mask=attention_mask,\n    )\n    prompt_embeds = prompt_embeds[0]\n\n    return prompt_embeds\n\n\ndef main(args, init_pipeline, send_opts):\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(\n        project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        
gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\n                \"Make sure to install wandb if you want to use it for logging during training.\"\n            )\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if (\n        args.train_text_encoder\n        and args.gradient_accumulation_steps > 1\n        and accelerator.num_processes > 1\n    ):\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            # DDA\n            # torch_dtype = (\n            #    torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            # )\n            # if args.prior_generation_precision == \"fp32\":\n            #     torch_dtype = torch.float32\n            # elif args.prior_generation_precision == \"fp16\":\n            #     torch_dtype = torch.float16\n            # elif args.prior_generation_precision == \"bf16\":\n            #     torch_dtype = torch.bfloat16\n            # DDA\n            pipeline = init_pipeline\n            pipeline.safety_checker = None\n            # pipeline = DiffusionPipeline.from_pretrained(\n            #     args.pretrained_model_name_or_path,\n            #     torch_dtype=torch_dtype,\n            #     safety_checker=None,\n            #     revision=args.revision,\n            # )\n       
     pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(\n                sample_dataset, batch_size=args.sample_batch_size\n            )\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            # pipeline.to(accelerator.device) # DDA already done\n\n            for example in tqdm(\n                sample_dataloader,\n                desc=\"Generating class images\",\n                disable=not accelerator.is_local_main_process,\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = (\n                        class_images_dir\n                        / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    )\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    if args.push_to_hub:\n        repo_id = create_repo(\n            repo_id=args.hub_model_id or Path(args.output_dir).name,\n            exist_ok=True,\n            token=args.hub_token,\n        ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.tokenizer_name, revision=args.revision, use_fast=False\n        )\n    elif args.pretrained_model_name_or_path:\n        tokenizer = 
init_pipeline.components[\"tokenizer\"]  # DDA\n        # tokenizer = AutoTokenizer.from_pretrained(\n        #     args.pretrained_model_name_or_path,\n        #     subfolder=\"tokenizer\",\n        #     revision=args.revision,\n        #     use_auth_token=args.hub_token,  # DDA\n        #     local_files_only=True,  # DDA\n        # )\n\n    # import correct text encoder class\n    # DDA\n    # text_encoder_cls = import_model_class_from_model_name_or_path(\n    #     args.pretrained_model_name_or_path,\n    #     args.revision\n    # )\n\n    # Load scheduler and models\n    # noise_scheduler = DDPMScheduler.from_pretrained(\n    #     args.pretrained_model_name_or_path,\n    #     subfolder=\"scheduler\",\n    #     use_auth_token=args.hub_token,  # DDA\n    #     local_files_only=True,  # DDA\n    # )\n\n    # text_encoder = text_encoder_cls.from_pretrained(\n    #     args.pretrained_model_name_or_path,\n    #     subfolder=\"text_encoder\",\n    #     revision=args.revision,\n    #     use_auth_token=args.hub_token,  # DDA\n    #     local_files_only=True,  # DDA\n    # )\n    # if model_has_vae(args):\n    #     vae = AutoencoderKL.from_pretrained(\n    #         args.pretrained_model_name_or_path,\n    #         subfolder=\"vae\",\n    #         revision=args.revision\n    #         use_auth_token=args.hub_token,  # DDA\n    #         local_files_only=True,  # DDA\n    #     )\n    # else:\n    #     vae = None\n    # unet = UNet2DConditionModel.from_pretrained(\n    #     args.pretrained_model_name_or_path,\n    #     subfolder=\"unet\",\n    #     revision=args.revision,\n    #     use_auth_token=args.hub_token,  # DDA\n    #     local_files_only=True,  # DDA\n    # )\n    # print(\"pipeline.disable_xformers_memory_efficient_attention()\")\n    # init_pipeline.disable_xformers_memory_efficient_attention()\n    noise_scheduler = init_pipeline.components[\"scheduler\"]  # DDA\n    text_encoder = init_pipeline.components[\"text_encoder\"]  # DDA\n    vae 
= init_pipeline.components[\"vae\"]  # DDA\n    unet = init_pipeline.components[\"unet\"]  # DDA\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        for model in models:\n            sub_dir = (\n                \"unet\"\n                if isinstance(model, type(accelerator.unwrap_model(unet)))\n                else \"text_encoder\"\n            )\n            model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n            # make sure to pop weight so that corresponding model is not saved again\n            weights.pop()\n\n    def load_model_hook(models, input_dir):\n        while len(models) > 0:\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):\n                # load transformers style into model\n                load_model = text_encoder_cls.from_pretrained(\n                    input_dir, subfolder=\"text_encoder\"\n                )\n                model.config = load_model.config\n            else:\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(\n                    input_dir, subfolder=\"unet\"\n                )\n                model.register_to_config(**load_model.config)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # TODO, how does this affect things outside of train_dreambooth?\n    if vae is not None:\n        vae.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warn(\n                  
  \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\n                \"xformers is not available. Make sure it is installed correctly\"\n            )\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \"Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training. copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. 
{low_precision_error_string}\"\n        )\n\n    if (\n        args.train_text_encoder\n        and accelerator.unwrap_model(text_encoder).dtype != torch.float32\n    ):\n        raise ValueError(\n            f\"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}.\"\n            f\" {low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate\n            * args.gradient_accumulation_steps\n            * args.train_batch_size\n            * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters())\n        if args.train_text_encoder\n        else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    if args.pre_compute_text_embeddings:\n\n        def compute_text_embeddings(prompt):\n            with torch.no_grad():\n                text_inputs = tokenize_prompt(\n                    tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length\n                )\n               
 prompt_embeds = encode_prompt(\n                    text_encoder,\n                    text_inputs.input_ids,\n                    text_inputs.attention_mask,\n                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                )\n\n            return prompt_embeds\n\n        pre_computed_encoder_hidden_states = compute_text_embeddings(\n            args.instance_prompt\n        )\n        validation_prompt_negative_prompt_embeds = compute_text_embeddings(\"\")\n\n        if args.validation_prompt is not None:\n            validation_prompt_encoder_hidden_states = compute_text_embeddings(\n                args.validation_prompt\n            )\n        else:\n            validation_prompt_encoder_hidden_states = None\n\n        if args.instance_prompt is not None:\n            pre_computed_instance_prompt_encoder_hidden_states = (\n                compute_text_embeddings(args.instance_prompt)\n            )\n        else:\n            pre_computed_instance_prompt_encoder_hidden_states = None\n\n        text_encoder = None\n        tokenizer = None\n\n        gc.collect()\n        torch.cuda.empty_cache()\n    else:\n        pre_computed_encoder_hidden_states = None\n        validation_prompt_encoder_hidden_states = None\n        validation_prompt_negative_prompt_embeds = None\n        pre_computed_instance_prompt_encoder_hidden_states = None\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n        encoder_hidden_states=pre_computed_encoder_hidden_states,\n        
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,\n        tokenizer_max_length=args.tokenizer_max_length,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(\n        len(train_dataloader) / args.gradient_accumulation_steps\n    )\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            unet,\n            text_encoder,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == 
\"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast to weight_dtype\n    if vae is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    if not args.train_text_encoder and text_encoder is not None:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(\n        len(train_dataloader) / args.gradient_accumulation_steps\n    )\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth\", config=vars(args))\n\n    # Train!\n    total_batch_size = (\n        args.train_batch_size\n        * accelerator.num_processes\n        * args.gradient_accumulation_steps\n    )\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(\n        f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\"\n    )\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (\n                num_update_steps_per_epoch * args.gradient_accumulation_steps\n            )\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(\n        range(global_step, args.max_train_steps),\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    # DDA\n    send(\"training\", \"start\", {}, send_opts)\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            
text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if (\n                args.resume_from_checkpoint\n                and epoch == first_epoch\n                and step < resume_step\n            ):\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n\n                if vae is not None:\n                    # Convert images to latent space\n                    model_input = vae.encode(\n                        batch[\"pixel_values\"].to(dtype=weight_dtype)\n                    ).latent_dist.sample()\n                    model_input = model_input * vae.config.scaling_factor\n                else:\n                    model_input = pixel_values\n\n                # Sample noise that we'll add to the model input\n                if args.offset_noise:\n                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(\n                        model_input.shape[0],\n                        model_input.shape[1],\n                        1,\n                        1,\n                        device=model_input.device,\n                    )\n                else:\n                    noise = torch.randn_like(model_input)\n                bsz, channels, height, width = model_input.shape\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0,\n                    noise_scheduler.config.num_train_timesteps,\n                    (bsz,),\n                    device=model_input.device,\n                )\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward 
diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(\n                    model_input, noise, timesteps\n                )\n\n                # Get the text embedding for conditioning\n                if args.pre_compute_text_embeddings:\n                    encoder_hidden_states = batch[\"input_ids\"]\n                else:\n                    encoder_hidden_states = encode_prompt(\n                        text_encoder,\n                        batch[\"input_ids\"],\n                        batch[\"attention_mask\"],\n                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                    )\n\n                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:\n                    noisy_model_input = torch.cat(\n                        [noisy_model_input, noisy_model_input], dim=1\n                    )\n\n                if args.class_labels_conditioning == \"timesteps\":\n                    class_labels = timesteps\n                else:\n                    class_labels = None\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    encoder_hidden_states,\n                    class_labels=class_labels,\n                ).sample\n\n                if model_pred.shape[1] == 6:\n                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(\n                        f\"Unknown prediction type 
{noise_scheduler.config.prediction_type}\"\n                    )\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = F.mse_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\"\n                    )\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(\n                        model_pred_prior.float(), target_prior.float(), reduction=\"mean\"\n                    )\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = F.mse_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\"\n                    )\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n 
                       # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [\n                                d for d in checkpoints if d.startswith(\"checkpoint\")\n                            ]\n                            checkpoints = sorted(\n                                checkpoints, key=lambda x: int(x.split(\"-\")[1])\n                            )\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = (\n                                    len(checkpoints) - args.checkpoints_total_limit + 1\n                                )\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(\n                                    f\"removing checkpoints: {', '.join(removing_checkpoints)}\"\n                                )\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(\n                                        args.output_dir, removing_checkpoint\n                                    )\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(\n                            args.output_dir, f\"checkpoint-{global_step}\"\n                        )\n                        
pipeline.save_pretrained(save_path)\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    images = []\n\n                    if (\n                        args.validation_prompt is not None\n                        and global_step % args.validation_steps == 0\n                    ):\n                        images = log_validation(\n                            text_encoder,\n                            tokenizer,\n                            unet,\n                            vae,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            epoch,\n                            validation_prompt_encoder_hidden_states,\n                            validation_prompt_negative_prompt_embeds,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    send(\"training\", \"done\", {}, send_opts)  # DDA\n    if accelerator.is_main_process:\n        pipeline_args = {}\n\n        if text_encoder is not None:\n            pipeline_args[\"text_encoder\"] = accelerator.unwrap_model(text_encoder)\n\n        if args.skip_save_text_encoder:\n            pipeline_args[\"text_encoder\"] = None\n\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=accelerator.unwrap_model(unet),\n            revision=args.revision,\n            **pipeline_args,\n            local_files_only=True,  # DDA\n        )\n\n        # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n        scheduler_args = {}\n\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = pipeline.scheduler.from_config(\n            pipeline.scheduler.config, **scheduler_args\n        )\n\n        pipeline.save_pretrained(args.output_dir, safe_serialization=True)\n\n        if args.push_to_hub:\n            # DDA\n            send(\"upload\", \"start\", {}, send_opts)\n\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n                pipeline=pipeline,\n            )\n            # repo.push_to_hub(\n            #    commit_message=\"End of training\",\n            #   # DDA need to think about this, quite nice to not block, then could\n            #    # upload while training next request.  But, timeout will kill an unused\n            #    # process...  what else?\n            #    blocking=True,  # DDA, was: False,\n            #    auto_lfs_prune=True,\n            # )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n                # DDA\n                # https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L3379\n                # Whether or not to run this method in the background. 
Background jobs are run sequentially without\n                # blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)\n                # object. Defaults to `False`.\n                # run_as_future: TODO\n            )\n\n            # DDA\n            send(\"upload\", \"done\", {}, send_opts)\n\n    accelerator.end_training()\n\n    # DDA\n    return {\"done\": True}\n"
  },
  {
    "path": "api/utils/__init__.py",
    "content": "from .storage import Storage\n"
  },
  {
    "path": "api/utils/storage/BaseStorage.py",
    "content": "import os\nimport re\nimport subprocess\nfrom abc import ABC, abstractmethod\nimport xtarfile as tarfile\n\n\nclass BaseArchive(ABC):\n    def __init__(self, path, status=None):\n        self.path = path\n        self.status = status\n\n    def updateStatus(self, type, progress):\n        if self.status:\n            self.status.update(type, progress)\n\n    def extract(self):\n        print(\"TODO\")\n\n    def splitext(self):\n        base, ext = os.path.splitext(self.path)\n        base, subext = os.path.splitext(base)\n        return base, ext, subext\n\n\nclass TarArchive(BaseArchive):\n    @staticmethod\n    def test(path):\n        return re.search(r\"\\.tar\", path)\n\n    def extract(self, dir, dry_run=False):\n        self.updateStatus(\"extract\", 0)\n        if not dir:\n            base, ext, subext = self.splitext()\n            parent_dir = os.path.dirname(self.path)\n            dir = os.path.join(parent_dir, base)\n\n        if not dry_run:\n            os.mkdir(dir)\n\n            def track_progress(tar):\n                i = 0\n                members = tar.getmembers()\n                for member in members:\n                    i += 1\n                    self.updateStatus(\"extract\", i / len(members))\n                    yield member\n\n            print(\"Extracting to \" + dir)\n            with tarfile.open(self.path, \"r\") as tar:\n                tar.extractall(path=dir, members=track_progress(tar))\n                tar.close()\n            subprocess.run([\"ls\", \"-l\", dir])\n            os.remove(self.path)\n\n        self.updateStatus(\"extract\", 1)\n        return dir  # , base, ext, subext\n\n\narchiveClasses = [TarArchive]\n\n\ndef Archive(path, **kwargs):\n    for ArchiveClass in archiveClasses:\n        if ArchiveClass.test(path):\n            return ArchiveClass(path, **kwargs)\n\n\nclass BaseStorage(ABC):\n    @staticmethod\n    @abstractmethod\n    def test(url):\n        return re.search(r\"^https?://\", 
url)\n\n    def __init__(self, url, **kwargs):\n        self.url = url\n        self.status = kwargs.get(\"status\", None)\n        self.query = {}\n\n    def updateStatus(self, type, progress):\n        if self.status:\n            self.status.update(type, progress)\n\n    def splitext(self):\n        base, ext = os.path.splitext(self.url)\n        base, subext = os.path.splitext(base)\n        return base, ext, subext\n\n    def get_filename(self):\n        return self.url.split(\"/\").pop()\n\n    @abstractmethod\n    def download_file(self, dest):\n        \"\"\"Download the file to `dest`\"\"\"\n        pass\n\n    def download_and_extract(self, fname, dir=None, dry_run=False):\n        \"\"\"\n        Downloads the file, and if it's an archive, extract it too.  Returns\n        the filename if not, or directory name (fname without extension) if\n        it was.\n        \"\"\"\n        if not fname:\n            fname = self.get_filename()\n\n        archive = Archive(fname, status=self.status)\n        if archive:\n            # TODO, streaming pipeline\n            self.download_file(fname)\n            return archive.extract(dir)\n        else:\n            self.download_file(fname)\n            return fname\n"
  },
  {
    "path": "api/utils/storage/BaseStorage_test.py",
    "content": "import unittest\nfrom . import Storage, S3Storage, HTTPStorage\n\n\nclass BaseStorageTest(unittest.TestCase):\n    def test_get_filename(self):\n        storage = Storage(\"http://host.com/dir/file.tar.zst\")\n        self.assertEqual(storage.get_filename(), \"file.tar.zst\")\n\n    class Download_and_extract(unittest.TestCase):\n        def test_file_only(self):\n            storage = Storage(\"http://host.com/dir/file.bin\")\n            result = storage.download_and_extract(dry_run=True)\n            self.assertEqual(result, \"file.bin\")\n\n        def test_file_archive(self):\n            storage = Storage(\"http://host.com/dir/file.tar.zst\")\n            result, base, ext, subext = storage.download_and_extract(dry_run=True)\n            self.assertEqual(result, \"file\")\n            self.assertEqual(base, \"file\")\n            self.assertEqual(ext, \"tar\")\n            self.assertEqual(subext, \"zst\")\n"
  },
  {
    "path": "api/utils/storage/HTTPStorage.py",
    "content": "import re\nimport os\nimport time\nimport requests\nfrom tqdm import tqdm\nfrom .BaseStorage import BaseStorage\nimport urllib.parse\n\n\ndef get_now():\n    return round(time.time() * 1000)\n\n\nclass HTTPStorage(BaseStorage):\n    @staticmethod\n    def test(url):\n        return re.search(r\"^https?://\", url)\n\n    def __init__(self, url, **kwargs):\n        super().__init__(url, **kwargs)\n        parts = self.url.split(\"#\", 1)\n        self.url = parts[0]\n        if len(parts) > 1:\n            self.query = urllib.parse.parse_qs(parts[1])\n\n    def upload_file(self, source, dest):\n        raise RuntimeError(\"HTTP PUT not implemented yet\")\n\n    def download_file(self, fname):\n        print(f\"Downloading {self.url} to {fname}...\")\n        resp = requests.get(self.url, stream=True)\n        total = int(resp.headers.get(\"content-length\", 0))\n        content_disposition = resp.headers.get(\"content-disposition\")\n        if content_disposition:\n            filename_search = re.search('filename=\"(.+)\"', content_disposition)\n            if filename_search:\n                self.filename = filename_search.group(1)\n        else:\n            print(\"Warning: content-disposition header is not found in the response.\")\n        # Can also replace 'file' with a io.BytesIO object\n        with open(fname, \"wb\") as file, tqdm(\n            desc=\"Downloading\",\n            total=total,\n            unit=\"iB\",\n            unit_scale=True,\n            unit_divisor=1024,\n        ) as bar:\n            total_written = 0\n            for data in resp.iter_content(chunk_size=1024):\n                size = file.write(data)\n                bar.update(size)\n                total_written += size\n                self.updateStatus(\"download\", total_written / total)\n"
  },
  {
    "path": "api/utils/storage/S3Storage.py",
    "content": "import boto3\nimport botocore\nimport re\nimport os\nimport time\nfrom tqdm import tqdm\nfrom botocore.client import Config\nfrom .BaseStorage import BaseStorage\n\nAWS_S3_ENDPOINT_URL = os.environ.get(\"AWS_S3_ENDPOINT_URL\", None)\nAWS_S3_DEFAULT_BUCKET = os.environ.get(\"AWS_S3_DEFAULT_BUCKET\", None)\nif AWS_S3_ENDPOINT_URL == \"\":\n    AWS_S3_ENDPOINT_URL = None\nif AWS_S3_DEFAULT_BUCKET == \"\":\n    AWS_S3_DEFAULT_BUCKET = None\n\n\ndef get_now():\n    return round(time.time() * 1000)\n\n\nclass S3Storage(BaseStorage):\n    def test(url):\n        return re.search(r\"^(https?\\+)?s3://\", url)\n\n    def __init__(self, url, **kwargs):\n        super().__init__(url, **kwargs)\n\n        if url.startswith(\"s3://\"):\n            url = \"https://\" + url[5:]\n        elif url.startswith(\"http+s3://\"):\n            url = \"http\" + url[7:]\n        elif url.startswith(\"https+s3://\"):\n            url = \"https\" + url[8:]\n\n        s3_dest = re.match(\n            r\"^(?P<endpoint>https?://[^/]*)(/(?P<bucket>[^/]+))?(/(?P<path>.*))?$\",\n            url,\n        ).groupdict()\n\n        if not s3_dest[\"endpoint\"] or s3_dest[\"endpoint\"].endswith(\"//\"):\n            s3_dest[\"endpoint\"] = AWS_S3_ENDPOINT_URL\n        if not s3_dest[\"bucket\"]:\n            s3_dest[\"bucket\"] = AWS_S3_DEFAULT_BUCKET\n        if not s3_dest[\"path\"] or s3_dest[\"path\"] == \"\":\n            s3_dest[\"path\"] = kwargs.get(\"default_path\", \"\")\n\n        self.endpoint_url = s3_dest[\"endpoint\"]\n        self.bucket_name = s3_dest[\"bucket\"]\n        self.path = s3_dest[\"path\"]\n\n        self._s3resource = None\n        self._s3client = None\n        self._bucket = None\n        print(\"self.endpoint_url\", self.endpoint_url)\n\n    def s3resource(self):\n        if self._s3resource:\n            return self._s3resource\n\n        self._s3 = boto3.resource(\n            \"s3\",\n            endpoint_url=self.endpoint_url,\n            
config=Config(signature_version=\"s3v4\"),\n        )\n        return self._s3\n\n    def s3client(self):\n        if self._s3client:\n            return self._s3client\n\n        self._s3client = boto3.client(\n            \"s3\",\n            endpoint_url=self.endpoint_url,\n            config=Config(signature_version=\"s3v4\"),\n        )\n        return self._s3client\n\n    def bucket(self):\n        if self._bucket:\n            return self._bucket\n\n        self._bucket = self.s3resource().Bucket(self.bucket_name)\n        return self._bucket\n\n    def upload_file(self, source, dest):\n        if not dest:\n            dest = self.path\n\n        upload_start = get_now()\n        file_size = os.stat(source).st_size\n        with tqdm(total=file_size, unit=\"B\", unit_scale=True, desc=\"Uploading\") as bar:\n            total_transferred = 0\n\n            def callback(bytes_transferred):\n                nonlocal total_transferred\n                bar.update(bytes_transferred),\n                total_transferred += bytes_transferred\n                self.updateStatus(\"upload\", total_transferred / file_size)\n\n            result = self.bucket().upload_file(\n                Filename=source, Key=dest, Callback=callback\n            )\n        print(result)\n        upload_total = get_now() - upload_start\n\n        return {\"$time\": upload_total}\n\n    def download_file(self, dest):\n        if not dest:\n            dest = self.path.split(\"/\").pop()\n        print(f\"Downloading {self.url} to {dest}...\")\n        object = self.s3resource().Object(self.bucket_name, self.path)\n        object.load()\n\n        with tqdm(\n            total=object.content_length, unit=\"B\", unit_scale=True, desc=\"Downloading\"\n        ) as bar:\n            total_transferred = 0\n\n            def callback(bytes_transferred):\n                nonlocal total_transferred\n                bar.update(bytes_transferred),\n                total_transferred += 
bytes_transferred\n                self.updateStatus(\"download\", total_transferred / object.content_length)\n\n            object.download_file(Filename=dest, Callback=callback)\n\n    def file_exists(self):\n        # res = self.s3client().list_objects_v2(\n        #    Bucket=self.bucket_name, Prefix=self.path, MaxKeys=1\n        # )\n        # return \"Contents\" in res\n        object = self.s3resource().Object(self.bucket_name, self.path)\n        try:\n            object.load()\n        except botocore.exceptions.ClientError as error:\n            if error.response[\"Error\"][\"Code\"] == \"404\":\n                return False\n            else:\n                raise\n        return True\n"
  },
  {
    "path": "api/utils/storage/S3Storage_test.py",
    "content": "import unittest\nimport os\nfrom .S3Storage import S3Storage, AWS_S3_ENDPOINT_URL, AWS_S3_DEFAULT_BUCKET\n\n\nclass S3StorageTest(unittest.TestCase):\n    def test_endpoint_only_s3(self):\n        storage = S3Storage(\"s3://hostname:9000\")\n        self.assertEqual(storage.endpoint_url, \"https://hostname:9000\")\n        self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)\n        self.assertEqual(storage.path, \"\")\n\n    def test_endpoint_only_http_s3(self):\n        storage = S3Storage(\"http+s3://hostname:9000\")\n        self.assertEqual(storage.endpoint_url, \"http://hostname:9000\")\n        self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)\n        self.assertEqual(storage.path, \"\")\n\n    def test_endpoint_only_https_s3(self):\n        storage = S3Storage(\"https+s3://hostname:9000\")\n        self.assertEqual(storage.endpoint_url, \"https://hostname:9000\")\n        self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)\n        self.assertEqual(storage.path, \"\")\n\n    def test_bucket_only(self):\n        storage = S3Storage(\"s3:///bucket\")\n        self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL)\n        self.assertEqual(storage.bucket_name, \"bucket\")\n        self.assertEqual(storage.path, \"\")\n\n    def test_url_with_bucket_and_file_only(self):\n        storage = S3Storage(\"s3:///bucket/file\")\n        self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL)\n        self.assertEqual(storage.bucket_name, \"bucket\")\n        self.assertEqual(storage.path, \"file\")\n\n    def test_full_url_with_subdirectory(self):\n        storage = S3Storage(\"s3://host/bucket/path/file\")\n        self.assertEqual(storage.endpoint_url, \"https://host\")\n        self.assertEqual(storage.bucket_name, \"bucket\")\n        self.assertEqual(storage.path, \"path/file\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "api/utils/storage/__init__.py",
    "content": "import os\nimport re\nfrom .S3Storage import S3Storage\nfrom .HTTPStorage import HTTPStorage\n\nclasses = [S3Storage, HTTPStorage]\n\n\ndef Storage(url, no_raise=False, **kwargs):\n    for StorageClass in classes:\n        if StorageClass.test(url):\n            return StorageClass(url, **kwargs)\n\n    if no_raise:\n        return None\n    else:\n        raise RuntimeError(\"No storage handler for: \" + url)\n"
  },
  {
    "path": "api/utils/storage/__init__test.py",
    "content": "import unittest\nfrom . import Storage, S3Storage, HTTPStorage\n\n\nclass StorageTest(unittest.TestCase):\n    def test_url_s3(self):\n        storage = Storage(\"s3://hostname:9000\")\n        self.assertTrue(isinstance(storage, S3Storage))\n\n    def test_url_http(self):\n        storage = Storage(\"http://hostname:9000\")\n        self.assertTrue(isinstance(storage, HTTPStorage))\n\n    def test_no_match_raise(self):\n        with self.assertRaises(RuntimeError):\n            storage = Storage(\"not_a_url\")\n\n    def test_no_match_no_raise(self):\n        storage = Storage(\"not_a_url\", no_raise=True)\n        self.assertIsNone(storage)\n"
  },
  {
    "path": "build",
    "content": "#!/bin/sh\n\n# This is my common way of building, but you can build however you like.\n# Note if you using a proxy, you need to have it running first.\n\nDOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain \\\n  docker build \\\n  -t gadicc/diffusers-api \\\n  -t gadicc/diffusers-api:test \\\n  --build-arg http_proxy=\"http://172.17.0.1:3128\" \\\n  --build-arg https_proxy=\"http://172.17.0.1:3128\" \\\n  \"$@\" .\n"
  },
  {
    "path": "docs/internal_safetensor_cache_flow.md",
    "content": "internal document to gather my thoughts\n\nRUNTIME_DOWNLOADS=1 (must be build arg)\nIMAGE_CLOUD_CACHE=\"s3://\" (can be env arg)\nCREATE_MISSING=1\n\ne.g. stabilityai/stable-diffusion-2-1-base\n\n1. Try download from IMAGE_CLOUD_CACHE\n  1. If found, use.\n  2. If not found:\n    1. Download from HuggingFace\n    2. In a subprocess:\n      1. Save with safetesors to tmp directory\n      2. Upload to IMAGE_CLOUD_CACHE\n      3. Delete original model dir, mv tmp to model dir (for next load)\n    1. Run inference with HF model.\n\nFileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2-1-base/refs/main'\n\n\nNVIDIA RTX Quadro 5000\n\nNO SAFETENSORS\nDownloaded in 462557 ms\nLoading model: stabilityai/stable-diffusion-2-1 (fp32)\nLoaded from disk in 3113 ms, to gpu in 1644 ms\n\nSAFETENSORS_FAST_GPU=0\nLoaded from disk in 2741 ms, to gpu in 557 ms\n\nSAFETENSORS_FAST_GPU=1\nLoaded from disk in 1153 ms, to gpu in 1495 ms\n\n\n\nNVIDIA RTX Quadro 5000 (fp16)\n\nNO SAFETENSORS\nDownloaded in 462557 ms\nLoading model: stabilityai/stable-diffusion-2-1-base (fp16)\nLoaded from disk in 2043 ms, to gpu in 1539 ms\n\nSAFETENSORS_FAST_GPU=0\n\n\nSAFETENSORS_FAST_GPU=1\nLoaded from disk in 1134 ms, to gpu in 1184 ms\n"
  },
  {
    "path": "docs/storage.md",
    "content": "# Storage\n\nMost URLs passed at build args or call args support special URLs, both to\nstore and retrieve files.\n\n**The Storage API is new and may change without notice, please keep a\ncareful look in the CHANGELOG when upgrading**.\n\n* [AWS S3](#s3)\n\n<a name=\"s3\"></a>\n## S3\n\n### Build Args\n\nSet the following **build-args**, as appropriate (through the Banana dashboard,\nby modifying the appropriate lines in the `Dockerfile`, or by specifying, e.g.\n`--build-arg AWS_ACCESS_KEY=\"XXX\"` etc.)\n\n```Dockerfile\nARG AWS_ACCESS_KEY_ID=\"XXX\"\nARG AWS_SECRET_ACCESS_KEY=\"XXX\"\nARG AWS_DEFAULT_REGION=\"us-west-1\" # best for banana\n# Optional.  ONLY SET THIS IF YOU KNOW YOU NEED TO.\n# Usually only if you're using non-Amazon S3-compatible storage.\n# If you need this, your provider will tell you exactly what\n# to put here.  Otherwise leave it blank to automatically use\n# the correct Amazon S3 endpoint.\nARG AWS_S3_ENDPOINT_URL\n```\n\n### Usage\n\nIn any URL where Storage is supported (e.g. dreambooth `dest_url`):\n\n  * `s3://endpoint/bucket/path/to/file`\n  * `s3:///bucket/file` (uses the default endpoint)\n  * `s3:///bucket` (for `dest_url`, filename will match your output model)\n  * `http+s3://...` (force http instead of https)"
  },
  {
    "path": "install.sh",
    "content": "#!/bin/sh\n\n# This entire file is no longer used but kept around for reference.\n\nif [ \"$FLASH_ATTENTION\" == \"1\" ]; then\n\n  echo \"Building with flash attention\"\n  git clone https://github.com/HazyResearch/flash-attention.git\n  cd flash-attention\n  git checkout cutlass\n  git submodule init\n  git submodule update\n  python setup.py install\n\n  cd ..\n  git clone https://github.com/HazyResearch/diffusers.git\n  pip install -e diffusers\n\nelse\n\n  echo \"Building without flash attention\"\n  git clone https://github.com/huggingface/diffusers\n  cd diffusers\n  git checkout v0.9.0\n  # 2022-11-21 [Community Pipelines] K-Diffusion Pipeline \n  # git checkout 182eb959e5efc8c77fa31394ca55376331c0ed25\n  # 2022-11-24 v_prediction (for SD 2.0)\n  # git checkout 30f6f4410487b6c1cf5be2da6c7e8fc844fb9a44\n  cd ..\n  pip install -e diffusers\n\nfi\n\n"
  },
  {
    "path": "package.json",
    "content": "{\n  \"name\": \"docker-diffusers-api\",\n  \"version\": \"0.0.1\",\n  \"main\": \"index.js\",\n  \"repository\": \"https://github.com/kiri-art/docker-diffusers-api.git\",\n  \"author\": \"Gadi Cohen <dragon@wastelands.net>\",\n  \"license\": \"MIT\",\n  \"private\": true,\n  \"devDependencies\": {\n    \"@semantic-release-plus/docker\": \"^3.1.2\",\n    \"@semantic-release/changelog\": \"^6.0.2\",\n    \"@semantic-release/git\": \"^10.0.1\",\n    \"semantic-release\": \"^19.0.5\",\n    \"semantic-release-plus\": \"^20.0.0\"\n  }\n}\n"
  },
  {
    "path": "prime.sh",
    "content": "#!/bin/sh\n\n# need to fix this.\n#download_model {'model_url': 's3://', 'model_id': 'Linaqruf/anything-v3.0', 'model_revision': 'fp16', 'hf_model_id': None}\n# {'normalized_model_id': 'models--Linaqruf--anything-v3.0--fp16'}\n#self.endpoint_url https://6fb830ebb3c8fed82a52524211d9c54e.r2.cloudflarestorage.com/diffusers\n#Downloading s3:// to /root/.cache/diffusers-api/models--Linaqruf--anything-v3.0--fp16.tar.zst...\n\n\nMODELS=(\n  # ID,precision,revision\n#  \"prompthero/openjourney-v2\"\n#  \"wd-1-4-anime_e1,,,hakurei/waifu-diffusion\"\n#  \"Linaqruf/anything-v3.0,fp16,diffusers\"\n#  \"Linaqruf/anything-v3.0,fp16,fp16\"\n#  \"stabilityai/stable-diffusion-2-1,fp16,fp16\"\n#  \"stabilityai/stable-diffusion-2-1-base,fp16,fp16\"\n#  \"stabilityai/stable-diffusion-2,fp16,fp16\"\n#  \"stabilityai/stable-diffusion-2-base,fp16,fp16\"\n#  \"CompVis/stable-diffusion-v1-4,fp16,fp16\"\n#  \"runwayml/stable-diffusion-v1-5,fp16,fp16\"\n#  \"runwayml/stable-diffusion-inpainting,fp16,fp16\"\n#  \"hakurei/waifu-diffusion,fp16,fp16\"\n#  \"hakurei/waifu-diffusion-v1-3,fp16,fp16\" # from checkpoint\n#  \"rinna/japanese-stable-diffusion\"\n#  \"OrangeMix/AbyssOrangeMix2,fp16\"\n#  \"OrangeMix/ElyOrangeMix,fp16\"\n#  \"OrangeMix/EerieOrangeMix,fp16\"\n#  \"OrangeMix/BloodOrangeMix,fp16\"\n  \"hakurei/wd-1-5-illusion-beta3,fp16,fp16\"\n  \"hakurei/wd-1-5-ink-beta3,fp16,fp16\"\n  \"hakurei/wd-1-5-mofu-beta3,fp16,fp16\"\n  \"hakurei/wd-1-5-radiance-beta3,fp16,fp16\",\n)\n\nfor MODEL_STR in ${MODELS[@]}; do\n  IFS=\",\" read -ra DATA <<<$MODEL_STR\n  MODEL_ID=${DATA[0]}\n  MODEL_PRECISION=${DATA[1]}\n  MODEL_REVISION=${DATA[2]}\n  HF_MODEL_ID=${DATA[3]}\n  python test.py txt2img \\\n    --call-arg MODEL_ID=\"$MODEL_ID\" \\\n    --call-arg HF_MODEL_ID=\"$HF_MODEL_ID\" \\\n    --call-arg MODEL_PRECISION=\"$MODEL_PRECISION\" \\\n    --call-arg MODEL_REVISION=\"$MODEL_REVISION\" \\\n    --call-arg MODEL_URL=\"s3://\" \\\n    --model-arg num_inference_steps=1\ndone\n"
  },
  {
    "path": "release.config.js",
    "content": "// https://semantic-release.gitbook.io/semantic-release/support/faq#can-i-use-semantic-release-to-publish-non-javascript-packages\nmodule.exports = {\n  \"branches\": [\"main\"],\n  \"plugins\": [\n    \"@semantic-release/commit-analyzer\",\n    \"@semantic-release/release-notes-generator\",\n    [\n      \"@semantic-release/changelog\",\n      {\n        \"changelogFile\": \"CHANGELOG.md\"\n      }\n    ],\n    [\n      \"@semantic-release/git\",\n      {\n        \"assets\": [\"CHANGELOG.md\"]\n      }\n    ],   \n    \"@semantic-release/github\",\n    [\"@semantic-release-plus/docker\", {\n      \"name\": \"gadicc/diffusers-api\"      \n      }]\n  ]\n}"
  },
  {
    "path": "requirements.txt",
    "content": "# we pin sanic==22.6.2 for compatibility with banana\nsanic==22.6.2\nsanic-ext==22.6.2\n# earlier sanics don't pin but require websockets<11.0\nwebsockets<11.0\n\n# now manually git cloned in a later step\n# diffusers==0.4.1\n# git+https://github.com/huggingface/diffusers@v0.5.1\n\ntransformers==4.33.1       # was 4.30.2 until 2023-09-08\nscipy==1.11.2              # was 1.10.0 until 2023-09-08\nrequests_futures==1.0.0\nnumpy==1.25.1              # was 1.24.1 until 2023-09-08\nscikit-image==0.21.0       # was 0.19.3 until 2023-09-08\naccelerate==0.22.0         # was 0.20.3 until 2023-09-08\ntriton==2.1.0              # was 2.0.0.post1 until 2023-09-08\nftfy==6.1.1\nspacy==3.6.1               # was 3.5.0 until 2023-09-08\nk-diffusion==0.0.16        # was 0.0.15 until 2023-09-08\nsafetensors==0.3.3         # was 0.3.1 until 2023-09-08\n\ntorch==2.0.1               # was 1.12.1 until 2023-07-19\ntorchvision==0.15.2\npytorch_lightning==2.0.8   # was 1.9.2 until 2023-09-08\n\nboto3==1.28.43             # was 1.26.57 until 2023-09-08\nbotocore==1.31.43          # was 1.29.57 until 2023-09-08\n\npytest==7.4.2              # was 7.2.1 until 2023-09-08\npytest-cov==4.1.0          # was 4.0.0 until 2023-09-08\n\ndatasets==2.14.5           # was 2.8.0 until 2023-09-08\nomegaconf==2.3.0\ntensorboard==2.14.0        # was 2.12.0 until 2023-09-08\n\nxtarfile[zstd]==0.1.0\n\nbitsandbytes==0.41.1       # was 0.40.2 until 2023-09-08\n\ninvisible-watermark==0.2.0 # released 2023-07-06\ncompel==2.0.2              # was 2.0.1 until 2023-09-08\njxlpy==0.9.2               # added 2023-09-11\n"
  },
  {
    "path": "run.sh",
    "content": "#!/bin/bash\n\ndocker run -it --rm \\\n  --gpus all  \\\n  -p 8000:8000 \\\n  -e http_proxy=\"http://172.17.0.1:3128\" \\\n  -e https_proxy=\"http://172.17.0.1:3128\" \\\n  -e REQUESTS_CA_BUNDLE=\"/usr/local/share/ca-certificates/squid-self-signed.crt\" \\\n  -e HF_AUTH_TOKEN=\"$HF_AUTH_TOKEN\" \\\n  -e AWS_ACCESS_KEY_ID=\"$AWS_ACCESS_KEY_ID\" \\\n  -e AWS_SECRET_ACCESS_KEY=\"$AWS_SECRET_ACCESS_KEY\" \\\n  -e AWS_DEFAULT_REGION=\"$AWS_DEFAULT_REGION\" \\\n  -e AWS_S3_ENDPOINT_URL=\"$AWS_S3_ENDPOINT_URL\" \\\n  -e AWS_S3_DEFAULT_BUCKET=\"$AWS_S3_DEFAULT_BUCKET\" \\\n  -v ~/root-cache:/root/.cache \\\n  \"$@\" gadicc/diffusers-api\n"
  },
  {
    "path": "run_integration_tests_on_lambda.sh",
    "content": "#!/bin/bash\n\nPAYLOAD_FILE=\"/tmp/request.json\"\n\nif [ -z \"$LAMBDA_API_KEY\" ]; then\n  echo \"No LAMBDA_API_KEY set\"\n  exit 1\nfi \n\nSSH_KEY_FILE=\"$HOME/.ssh/diffusers-api-test.pem\"\nif [ ! -f \"$SSH_KEY_FILE\" ]; then\n  curl -L $DDA_TEST_PEM > $SSH_KEY_FILE\n  chmod 600 $SSH_KEY_FILE\nfi\n\n#curl -u $LAMBDA_API_KEY: https://cloud.lambdalabs.com/api/v1/instances\n\n# TODO, find an available instance\n# https://cloud.lambdalabs.com/api/v1/instance-types\n\nlambda_run() {\n  # $1 = lambda instance-operation\n  if [ -z \"$2\" ] ; then\n    RESULT=$(\n      curl -su ${LAMBDA_API_KEY}: \\\n        https://cloud.lambdalabs.com/api/v1/$1 \\\n        -H \"Content-Type: application/json\"\n    )\n  else\n    RESULT=$(\n      curl -su ${LAMBDA_API_KEY}: \\\n        https://cloud.lambdalabs.com/api/v1/$1 \\\n        -d @$2 -H \"Content-Type: application/json\"\n    )\n  fi\n\n  if [ $? -eq 1 ]; then\n    echo \"curl failed\"\n    exit 1\n  fi\n\n  if [ \"$RESULT\" != \"\" ]; then\n    echo $RESULT | jq -e .error >& /dev/null\n    if [ $? -eq 0 ]; then\n      echo \"lambda error\"\n      echo $RESULT\n      exit 1\n    fi\n  fi\n}\n\ninstance_create() {\n  echo -n \"Creating instance...\"\n  local RESULT=\"\"\n  cat > $PAYLOAD_FILE << __END__\n  {\n    \"region_name\": \"us-west-1\",\n    \"instance_type_name\": \"gpu_1x_a10\",\n    \"ssh_key_names\": [\n      \"diffusers-api-test\"\n    ],\n    \"file_system_names\": [],\n    \"quantity\": 1\n  }\n__END__\n\n  lambda_run \"instance-operations/launch\" $PAYLOAD_FILE\n  # echo $RESULT\n  INSTANCE_ID=$(echo $RESULT | jq -re '.data.instance_ids[0]')\n  echo \"$INSTANCE_ID\"\n  if [ $? 
-eq 1 ]; then\n    echo \"jq failed\"\n    exit 1\n  fi\n}\n\ninstance_terminate() {\n  # $1 = INSTANCE_ID\n  echo \"Terminating instance $1\"\n  cat > $PAYLOAD_FILE << __END__\n  {\n    \"instance_ids\": [\n      \"$1\"\n    ]\n  }\n__END__\n  lambda_run \"instance-operations/terminate\" $PAYLOAD_FILE\n  echo $RESULT\n}\n\ndeclare -A IPS\ninstance_wait() {\n  INSTANCE_ID=\"$1\"\n  echo -n \"Waiting for $INSTANCE_ID\"\n  STATUS=\"\"\n  LAST_STATUS=\"\"\n  while [ \"$STATUS\" != \"active\" ] ; do\n    echo -n \".\"\n    lambda_run \"instances/$INSTANCE_ID\"\n    STATUS=$(echo $RESULT | jq -r '.data.status')\n    if [ \"$STATUS\" != \"$LAST_STATUS\" ]; then\n      # echo $RESULT\n      # echo STATUS $STATUS\n      LAST_STATUS=$STATUS\n    fi\n    sleep 1\n  done\n  echo\n\n  IP=$(echo $RESULT | jq -r '.data.ip')\n  echo STATUS $STATUS\n  echo IP $IP\n  IPS[\"$INSTANCE_ID\"]=$IP\n}\n\ninstance_run_script() {\n  INSTANCE_ID=\"$1\"\n  SCRIPT=\"$2\"\n  DIRECTORY=\"${3:-'.'}\"\n  IP=${IPS[\"$INSTANCE_ID\"]}\n\n  echo \"instance_run_script $1 $2 $3\"\n  ssh -i $SSH_KEY_FILE ubuntu@$IP \"cd $DIRECTORY && bash -s\" < $SCRIPT\n  return $?\n}\n\ninstance_run_command() {\n  INSTANCE_ID=\"$1\"\n  CMD=\"$2\"\n  DIRECTORY=\"${3:-'.'}\"\n  IP=${IPS[\"$INSTANCE_ID\"]}\n\n  echo \"instance_run_command $1 $2\"\n  ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new ubuntu@$IP \"cd $DIRECTORY && $CMD\"\n  return $?\n}\n\ninstance_rsync() {\n  INSTANCE_ID=\"$1\"\n  SOURCE=\"$2\"\n  DEST=\"$3\"\n  IP=${IPS[\"$INSTANCE_ID\"]}\n\n  echo \"instance_rsync $1 $2 $3\"\n  rsync -avzPe \"ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new\" --filter=':- .gitignore' --exclude=\".git\" $SOURCE ubuntu@$IP:$DEST\n  return $?\n}\n\n# Image Method 3, preparation (TODO, arg to specify which method)\ndocker build -t gadicc/diffusers-api:test .\ndocker push gadicc/diffusers-api:test\n\ninstance_create\n# INSTANCE_ID=\"913e06f669bf4e799c6223801eb82f40\"\n\ninstance_wait $INSTANCE_ID\n\ncommands() 
{\n  instance_run_command $INSTANCE_ID \"echo 'export HF_AUTH_TOKEN=\\\"$HF_AUTH_TOKEN\\\"' >> ~/.bashrc\"\n\n  # Whether to build or just for test scripts, lets transfer this checkout.\n  instance_rsync $INSTANCE_ID . docker-diffusers-api\n\n  instance_run_command $INSTANCE_ID \"sudo apt-get update\"\n  if [ $? -eq 1 ]; then return 1 ; fi\n  instance_run_command $INSTANCE_ID \"sudo apt install -yqq python3.9\"\n  if [ $? -eq 1 ]; then return 1 ; fi\n  instance_run_command $INSTANCE_ID \"python3.9 -m pip install -r docker-diffusers-api/tests/integration/requirements.txt\"\n  if [ $? -eq 1 ]; then return 1 ; fi\n  instance_run_command $INSTANCE_ID \"sudo usermod -aG docker ubuntu\"\n  if [ $? -eq 1 ]; then return 1 ; fi\n\n  # Image Method 1: Transfer entire image\n  # This turned out to be way too slow, quicker to rebuild on lambda\n  # Longer term, I guess we need our own container registry.\n  # echo \"Saving and transferring docker image to Lambda...\"\n  # IP=${IPS[\"$INSTANCE_ID\"]}\n  # docker save gadicc/diffusers-api:latest \\\n  #   | xz \\\n  #   | pv \\\n  #   | ssh -i $SSH_KEY_FILE ubuntu@$IP docker load\n  # if [ $? -eq 1 ]; then return 1 ; fi\n\n  # Image Method 2: Build on LambdaLabs\n  #if [ $? -eq 1 ]; then return 1 ; fi\n  #instance_run_command $INSTANCE_ID \"docker build -t gadicc/diffusers-api .\" docker-diffusers-api\n\n  # Image Method 3: Just upload new layers; Lambda has fast downloads from registry\n  # At start of script we have docker build/push.  Now let's pull:\n  instance_run_command $INSTANCE_ID \"docker pull gadicc/diffusers-api:test\"\n\n  # instance_run_script $INSTANCE_ID run_integration_tests.sh docker-diffusers-api\n  instance_run_command $INSTANCE_ID \"export HF_AUTH_TOKEN=\\\"$HF_AUTH_TOKEN\\\" && python3.9 -m pytest -s tests/integration\" docker-diffusers-api\n}\n\ncommands\nRETURN_VALUE=$?\n\ninstance_terminate $INSTANCE_ID\n\nexit $RETURN_VALUE"
  },
  {
    "path": "scripts/devContainerPostCreate.sh",
    "content": "#!/bin/bash\n\n# devcontainer.json postCreateCommand\n\necho\necho Initialize conda bindings for bash\nconda init bash\n\necho Activating\nsource /opt/conda/bin/activate base\n\necho Installing dev dependencies\npip install watchdog\n"
  },
  {
    "path": "scripts/devContainerServer.sh",
    "content": "#!/bin/bash\n\nsource /opt/conda/bin/activate base\n\nln -sf /api/diffusers .\n\nwatchmedo auto-restart --recursive -d api python api/server.py"
  },
  {
    "path": "scripts/patchmatch-setup.sh",
    "content": "#!/bin/sh\n\nif [ \"$USE_PATCHMATCH\" != \"1\" ]; then\n  echo \"Skipping PyPatchMatch install because USE_PATCHMATCH=$USE_PATCHMATCH\"\n  mkdir PyPatchMatch\n  touch PyPatchMatch/patch_match.py\n  exit\nfi\n\necho \"Installing PyPatchMatch because USE_PATCHMATCH=$USE_PATCHMATCH\"\napt-get install -yqq libopencv-dev python3-opencv > /dev/null\ngit clone https://github.com/lkwq007/PyPatchMatch\ncd PyPatchMatch\ngit checkout 0ae9b8bbdc83f84214405376f13a2056568897fb\nsed -i '0,/if os.name!=\"nt\":/s//if False:/' patch_match.py\nmake\n"
  },
  {
    "path": "scripts/permutations.yaml",
    "content": "list:\n\n  - name: sd-v1-5\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: runwayml/stable-diffusion-v1-5\n    PIPELINE: ALL\n\n  - name: sd-v1-4\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: CompVis/stable-diffusion-v1-4\n    PIPELINE: ALL\n\n  - name: sd-inpaint\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: runwayml/stable-diffusion-inpainting\n    PIPELINE: StableDiffusionInpaintPipeline\n\n  - name: sd-waifu\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: hakurei/waifu-diffusion\n    PIPELINE: ALL\n\n  - name: sd-waifu-v1-3\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: hakurei/waifu-diffusion-v1-3\n    CHECKPOINT_URL: https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\n    PIPELINE: ALL\n\n  - name: sd-jp\n    HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n    MODEL_ID: rinna/japanese-stable-diffusion\n    PIPELINE: ALL\n"
  },
  {
    "path": "scripts/permute.sh",
    "content": "#!/usr/bin/env bash\n\n# Run this in banana-sd-base's PARENT directory.\n# Modify the below first per your preferences\n\n# Requires `yq` from https://github.com/mikefarah/yq\n# Note, there are two yqs.  In Archlinux the package is \"go-yq\".\n\nif [ -z \"$1\" ]; then \n  echo \"Using 'scripts/permutations.yaml' as default INFILE\"\n  echo \"You can also run: permutate.sh MY_INFILE\"\n  INFILE='scripts/permutations.yaml'\nelse\n  INFILE=$1\nfi\n\nif [ -z \"$TARGET_REPO_BASE\" ]; then\n  TARGET_REPO_BASE=\"git@github.com:kiri-art\"\n  echo 'No TARGET_REPO_BASE found, using \"$TARGET_REPO_BASE\"'\nfi\n\npermutations=$(yq e -o=j -I=0 '.list[]' $INFILE)\n\n# Needed for ! expansion in cp command further down.\nshopt -s extglob\n# Include dot files in expansion for .git .gitignore\nshopt -s dotglob\n\nCOUNTER=0\ndeclare -A vars\n\nmkdir -p permutations\n\nwhile IFS=\"=\" read permutation; do\n  # e.g. Permutation #1: banana-sd-txt2img\n  NAME=$(echo \"$permutation\" | yq e '.name')\n  COUNTER=$[$COUNTER + 1]\n  echo\n  echo \"Permutation #$COUNTER: $NAME\"\n\n  while IFS=\"=\" read -r key value\n  do\n    if [ \"$key\" != \"name\" ]; then\n      if [ \"${value:0:1}\" == \"$\" ]; then\n        # For e.g. 
\"$HF_AUTH_TOKEN\", expand from environment\n        value=\"${value:1}\"\n        vars[$key]=${!value}\n      else\n        vars[$key]=$value;\n      fi\n    fi\n  done < <(echo $permutation | yq e 'to_entries | .[] | (.key + \"=\" + .value)')\n\n  if [ -d \"permutations/$NAME\" ]; then \n    echo \"./permutations/$NAME already exists, skipping...\"\n    echo \"Run 'rm -rf permutations/$NAME' first to remake this permutation\"\n    echo \"In a later release, we'll merge updates in this case.\"\n    continue\n  fi\n\n  # echo \"mkdir permutations/$NAME\"\n  mkdir permutations/$NAME\n  # echo 'cp -a ./!(permutations|scripts|root-cache) permutations/$NAME'\n  cp -a ./!(permutations|scripts|root-cache) permutations/$NAME\n  # echo cd permutations/$NAME\n  cd permutations/$NAME\n\n  echo \"Substituting variables in Dockerfile\"\n  for key in \"${!vars[@]}\"; do\n    value=\"${vars[$key]}\"\n    sed -i \"s@^ARG $key.*\\$@ARG $key=\\\"$value\\\"@\" Dockerfile\n  done\n\n  diffusers=${vars[diffusers]}\n  if [ \"$diffusers\" ]; then\n    echo \"Replacing diffusers with $diffusers\"\n    echo \"!!! NOT DONE YET !!!\"\n  fi\n\n  mkdir root-cache\n  touch root-cache/non-empty-directory\n  git add root-cache\n\n  git remote rm origin\n  git remote add upstream git@github.com:kiri-art/docker-diffusers-api.git\n  git remote add origin $TARGET_REPO_BASE/$NAME.git\n\n  echo git commit -a -m \"$NAME permutation variables\"\n  git commit -a -m \"$NAME permutation variables\"\n\n  # echo \"cd ../..\"\n  cd ../..\n  echo\ndone <<EOF\n$permutations\nEOF\n"
  },
  {
    "path": "test.py",
    "content": "# This file is used to verify your http server acts as expected\n# Run it with `python3 test.py``\n\nimport requests\nimport base64\nimport os\nimport json\nimport sys\nimport time\nimport datetime\nimport argparse\nimport distutils\nfrom uuid import uuid4\nfrom io import BytesIO\nfrom PIL import Image\nfrom pathlib import Path, PosixPath\n\n# path = os.path.dirname(os.path.realpath(sys.argv[0]))\npath = \".\"\nTESTS = path + os.sep + \"tests\"\nFIXTURES = TESTS + os.sep + \"fixtures\"\nOUTPUT = TESTS + os.sep + \"output\"\nTEST_URL = os.environ.get(\"TEST_URL\", \"http://localhost:8000/\")\nBANANA_API_URL = os.environ.get(\"BANANA_API_URL\", \"https://api.banana.dev\")\nPath(OUTPUT).mkdir(parents=True, exist_ok=True)\n\n\ndef b64encode_file(filename: str):\n    path = (\n        filename\n        if isinstance(filename, PosixPath)\n        else os.path.join(FIXTURES, filename)\n    )\n    with open(path, \"rb\") as file:\n        return base64.b64encode(file.read()).decode(\"ascii\")\n\n\ndef output_path(filename: str):\n    return os.path.join(OUTPUT, filename)\n\n\n# https://stackoverflow.com/a/1094933/1839099\ndef sizeof_fmt(num, suffix=\"B\"):\n    for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n        if abs(num) < 1024.0:\n            return f\"{num:3.1f}{unit}{suffix}\"\n        num /= 1024.0\n    return f\"{num:.1f}Yi{suffix}\"\n\n\ndef decode_and_save(image_byte_string: str, name: str):\n    image_encoded = image_byte_string.encode(\"utf-8\")\n    image_bytes = BytesIO(base64.b64decode(image_encoded))\n    image = Image.open(image_bytes)\n    fp = output_path(name + \".png\")\n    image.save(fp)\n    print(\"Saved \" + fp)\n    size_formatted = sizeof_fmt(os.path.getsize(fp))\n\n    return (\n        f\"[{image.width}x{image.height} {image.format} image, {size_formatted} bytes]\"\n    )\n\n\nall_tests = {}\n\n\ndef test(name, inputs):\n    global all_tests\n    all_tests.update({name: inputs})\n\n\ndef 
runTest(name, args, extraCallInputs, extraModelInputs):\n    origInputs = all_tests.get(name)\n    inputs = {\n        \"modelInputs\": origInputs.get(\"modelInputs\", {}).copy(),\n        \"callInputs\": origInputs.get(\"callInputs\", {}).copy(),\n    }\n    inputs.get(\"callInputs\").update(extraCallInputs)\n    inputs.get(\"modelInputs\").update(extraModelInputs)\n\n    print(\"Running test: \" + name)\n\n    inputs_to_log = {\n        \"modelInputs\": inputs[\"modelInputs\"].copy(),\n        \"callInputs\": inputs[\"callInputs\"].copy(),\n    }\n    model_inputs_to_log = inputs_to_log[\"modelInputs\"]\n\n    for key in [\"init_image\", \"image\"]:\n        if key in model_inputs_to_log:\n            model_inputs_to_log[key] = \"[image]\"\n\n    instance_images = model_inputs_to_log.get(\"instance_images\", None)\n    if instance_images:\n        model_inputs_to_log[\"instance_images\"] = f\"[Array({len(instance_images)})]\"\n\n    print(json.dumps(inputs_to_log, indent=4))\n    print()\n\n    start = time.time()\n    if args.get(\"banana\", None):\n        BANANA_API_KEY = os.getenv(\"BANANA_API_KEY\")\n        BANANA_MODEL_KEY = os.getenv(\"BANANA_MODEL_KEY\")\n        if BANANA_MODEL_KEY == None or BANANA_API_KEY == None:\n            print(\"Error: BANANA_API_KEY or BANANA_MODEL_KEY not set, aborting...\")\n            sys.exit(1)\n\n        payload = {\n            \"id\": str(uuid4()),\n            \"created\": int(time.time()),\n            \"apiKey\": BANANA_API_KEY,\n            \"modelKey\": BANANA_MODEL_KEY,\n            \"modelInputs\": inputs,\n            \"startOnly\": False,\n        }\n\n        response = requests.post(f\"{BANANA_API_URL}/start/v4/\", json=payload)\n\n        result = response.json()\n        callID = result.get(\"callID\")\n\n        if result.get(\"finished\", None) == False:\n            while result.get(\n                \"message\", None\n            ) != \"success\" and not \"error\" in result.get(\"message\", None):\n    
            secondsSinceStart = time.time() - start\n                print(str(datetime.datetime.now()) + f\": t+{secondsSinceStart:.1f}s\")\n                print(json.dumps(result, indent=4))\n                print\n                payload = {\n                    \"id\": str(uuid4()),\n                    \"created\": int(time.time()),\n                    \"longPoll\": True,\n                    \"apiKey\": BANANA_API_KEY,\n                    \"callID\": callID,\n                }\n                response = requests.post(f\"{BANANA_API_URL}/check/v4/\", json=payload)\n                result = response.json()\n\n        modelOutputs = result.get(\"modelOutputs\", None)\n        if modelOutputs == None:\n            finish = time.time() - start\n            print(f\"Request took {finish:.1f}s\")\n            print(result)\n            return\n        result = modelOutputs[0]\n    elif args.get(\"runpod\", None):\n        RUNPOD_API_URL = \"https://api.runpod.ai/v1/\"\n        RUNPOD_API_KEY = os.getenv(\"RUNPOD_API_KEY\")\n        RUNPOD_MODEL_KEY = os.getenv(\"RUNPOD_MODEL_KEY\")\n        if not (RUNPOD_API_KEY and RUNPOD_MODEL_KEY):\n            print(\"Error: RUNPOD_API_KEY or RUNPOD_MODEL_KEY not set, aborting...\")\n            sys.exit(1)\n\n        url_base = RUNPOD_API_URL + RUNPOD_MODEL_KEY\n\n        payload = {\n            \"input\": inputs,\n        }\n        print(url_base + \"/run\")\n        response = requests.post(\n            url_base + \"/run\",\n            json=payload,\n            headers={\"Authorization\": \"Bearer \" + RUNPOD_API_KEY},\n        )\n\n        if response.status_code != 200:\n            print(\"Unexpected HTTP response code: \" + str(response.status_code))\n            sys.exit(1)\n\n        print(response)\n        result = response.json()\n        print(result)\n\n        id = result[\"id\"]\n\n        while result[\"status\"] != \"COMPLETED\":\n            time.sleep(1)\n            response = requests.get(\n       
         f\"{url_base}/status/{id}\",\n                headers={\"Authorization\": \"Bearer \" + RUNPOD_API_KEY},\n            )\n            result = response.json()\n\n        result = result[\"output\"]\n\n    else:\n        test_url = args.get(\"test_url\", None) or TEST_URL\n        call_inputs = inputs[\"callInputs\"]\n        stream_events = call_inputs and call_inputs.get(\"streamEvents\", 0) != 0\n        print({\"stream_events\": stream_events})\n        if stream_events:\n            result = None\n            response = requests.post(test_url, json=inputs, stream=True)\n            for line in response.iter_lines():\n                if line:\n                    result = json.loads(line)\n                    if not result.get(\"$timings\", None):\n                        print(result)\n        else:\n            response = requests.post(test_url, json=inputs)\n            try:\n                result = response.json()\n            except requests.exceptions.JSONDecodeError as error:\n                print(error)\n                print(response.text)\n                sys.exit(1)\n\n    finish = time.time() - start\n    timings = result.get(\"$timings\")\n\n    if timings:\n        timings_str = json.dumps(\n            dict(\n                map(\n                    lambda item: (\n                        item[0],\n                        f\"{item[1]/1000/60:.1f}m\"\n                        if item[1] > 60000\n                        else f\"{item[1]/1000:.1f}s\"\n                        if item[1] > 1000\n                        else str(item[1]) + \"ms\",\n                    ),\n                    timings.items(),\n                )\n            )\n        ).replace('\"', \"\")[1:-1]\n        print(f\"Request took {finish:.1f}s ({timings_str})\")\n    else:\n        print(f\"Request took {finish:.1f}s\")\n\n    if (\n        result.get(\"images_base64\", None) == None\n        and result.get(\"image_base64\", None) == None\n    ):\n        error = 
result.get(\"$error\", None)\n        if error:\n            code = error.get(\"code\", None)\n            name = error.get(\"name\", None)\n            message = error.get(\"message\", None)\n            stack = error.get(\"stack\", None)\n            if code and name and message and stack:\n                print()\n                title = f\"Exception {code} on container:\"\n                print(title)\n                print(\"-\" * len(title))\n                # print(f'{name}(\"{message}\")') - stack includes it.\n                print(stack)\n                return\n\n        print(json.dumps(result, indent=4))\n        print()\n        return result\n\n    images_base64 = result.get(\"images_base64\", None)\n    if images_base64:\n        for idx, image_byte_string in enumerate(images_base64):\n            images_base64[idx] = decode_and_save(image_byte_string, f\"{name}_{idx}\")\n    else:\n        result[\"image_base64\"] = decode_and_save(result[\"image_base64\"], name)\n\n    print()\n    print(json.dumps(result, indent=4))\n    print()\n    return result\n\n\ntest(\n    \"txt2img\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"realistic field of grass\",\n            \"num_inference_steps\": 20,\n        },\n        \"callInputs\": {\n            # \"MODEL_ID\": \"<override_default>\",  # (default)\n            # \"PIPELINE\": \"StableDiffusionPipeline\",  # (default)\n            # \"SCHEDULER\": \"DPMSolverMultistepScheduler\",  # (default)\n            # \"xformers_memory_efficient_attention\": False,  # (default)\n        },\n    },\n)\n\n# multiple images\ntest(\n    \"txt2img-multiple\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"realistic field of grass\",\n            \"num_images_per_prompt\": 2,\n        }\n    },\n)\n\n\ntest(\n    \"img2img\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"A fantasy landscape, trending on artstation\",\n            \"image\": 
b64encode_file(\"sketch-mountains-input.jpg\"),\n        },\n        \"callInputs\": {\n            \"PIPELINE\": \"StableDiffusionImg2ImgPipeline\",\n        },\n    },\n)\n\ntest(\n    \"inpaint-v1-4\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"a cat sitting on a bench\",\n            \"image\": b64encode_file(\"overture-creations-5sI6fQgYIuo.png\"),\n            \"mask_image\": b64encode_file(\"overture-creations-5sI6fQgYIuo_mask.png\"),\n        },\n        \"callInputs\": {\n            \"MODEL_ID\": \"CompVis/stable-diffusion-v1-4\",\n            \"PIPELINE\": \"StableDiffusionInpaintPipelineLegacy\",\n            \"SCHEDULER\": \"DDIMScheduler\",  # Note, as of diffusers 0.3.0, no LMS yet\n        },\n    },\n)\n\ntest(\n    \"inpaint-sd\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"a cat sitting on a bench\",\n            \"image\": b64encode_file(\"overture-creations-5sI6fQgYIuo.png\"),\n            \"mask_image\": b64encode_file(\"overture-creations-5sI6fQgYIuo_mask.png\"),\n        },\n        \"callInputs\": {\n            \"MODEL_ID\": \"runwayml/stable-diffusion-inpainting\",\n            \"PIPELINE\": \"StableDiffusionInpaintPipeline\",\n            \"SCHEDULER\": \"DDIMScheduler\",  # Note, as of diffusers 0.3.0, no LMS yet\n        },\n    },\n)\n\ntest(\n    \"checkpoint\",\n    {\n        \"modelInputs\": {\n            \"prompt\": \"1girl\",\n        },\n        \"callInputs\": {\n            \"MODEL_ID\": \"hakurei/waifu-diffusion-v1-3\",\n            \"MODEL_URL\": \"s3://\",\n            \"CHECKPOINT_URL\": \"http://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\",\n        },\n    },\n)\n\nif os.getenv(\"USE_PATCHMATCH\"):\n    test(\n        \"outpaint\",\n        {\n            \"modelInputs\": {\n                \"prompt\": \"girl with a pearl earing standing in a big room\",\n                \"image\": b64encode_file(\"girl_with_pearl_earing_outpainting_in.png\"),\n    
        },\n            \"callInputs\": {\n                \"MODEL_ID\": \"CompVis/stable-diffusion-v1-4\",\n                \"PIPELINE\": \"StableDiffusionInpaintPipelineLegacy\",\n                \"SCHEDULER\": \"DDIMScheduler\",  # Note, as of diffusers 0.3.0, no LMS yet\n                \"FILL_MODE\": \"patchmatch\",\n            },\n        },\n    )\n\n# Actually we just want this to be a non-default test?\nif True or os.getenv(\"USE_DREAMBOOTH\"):\n    test(\n        \"dreambooth\",\n        # If you're calling from the command line, don't forget to a\n        # specify a destination if you want your fine-tuned model to\n        # be uploaded somewhere at the end.\n        {\n            \"modelInputs\": {\n                \"instance_prompt\": \"a photo of sks dog\",\n                \"instance_images\": list(\n                    map(\n                        b64encode_file,\n                        list(Path(\"tests/fixtures/dreambooth\").iterdir()),\n                    )\n                ),\n                # Option 1: upload to HuggingFace (see notes below)\n                # Make sure your HF API token has read/write access.\n                # \"hub_model_id\": \"huggingFaceUsername/targetModelName\",\n                # \"push_to_hub\": True,\n            },\n            \"callInputs\": {\n                \"train\": \"dreambooth\",\n                # Option 2: store on S3.  Note the **s3:///* (x3).  
See notes below.\n                # \"dest_url\": \"s3:///bucket/filename.tar.zst\".\n            },\n        },\n    )\n\n\ndef main(tests_to_run, args, extraCallInputs, extraModelInputs):\n    invalid_tests = []\n    for test in tests_to_run:\n        if all_tests.get(test, None) == None:\n            invalid_tests.append(test)\n\n    if len(invalid_tests) > 0:\n        print(\"No such tests: \" + \", \".join(invalid_tests))\n        exit(1)\n\n    for test in tests_to_run:\n        runTest(test, args, extraCallInputs, extraModelInputs)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--banana\", required=False, action=\"store_true\")\n    parser.add_argument(\"--runpod\", required=False, action=\"store_true\")\n    parser.add_argument(\n        \"--xmfe\",\n        required=False,\n        default=None,\n        type=lambda x: bool(distutils.util.strtobool(x)),\n    )\n    parser.add_argument(\"--scheduler\", required=False, type=str)\n    parser.add_argument(\"--call-arg\", action=\"append\", type=str)\n    parser.add_argument(\"--model-arg\", action=\"append\", type=str)\n\n    args, tests_to_run = parser.parse_known_args()\n\n    call_inputs = {}\n    model_inputs = {}\n\n    if args.call_arg:\n        for arg in args.call_arg:\n            name, value = arg.split(\"=\", 1)\n            if value.lower() == \"true\":\n                value = True\n            elif value.lower() == \"false\":\n                value = False\n            elif value.isdigit():\n                value = int(value)\n            elif value.replace(\".\", \"\", 1).isdigit():\n                value = float(value)\n            call_inputs.update({name: value})\n\n    if args.model_arg:\n        for arg in args.model_arg:\n            name, value = arg.split(\"=\", 1)\n            if value.lower() == \"true\":\n                value = True\n            elif value.lower() == \"false\":\n                value = False\n            elif 
value.isdigit():\n                value = int(value)\n            elif value.replace(\".\", \"\", 1).isdigit():\n                value = float(value)\n            model_inputs.update({name: value})\n\n    if args.xmfe != None:\n        call_inputs.update({\"xformers_memory_efficient_attention\": args.xmfe})\n    if args.scheduler:\n        call_inputs.update({\"SCHEDULER\": args.scheduler})\n\n    if len(tests_to_run) < 1:\n        print(\n            \"Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]\"\n        )\n        sys.exit()\n    elif len(tests_to_run) == 1 and (\n        tests_to_run[0] == \"ALL\" or tests_to_run[0] == \"all\"\n    ):\n        tests_to_run = list(all_tests.keys())\n\n    main(\n        tests_to_run,\n        vars(args),\n        extraCallInputs=call_inputs,\n        extraModelInputs=model_inputs,\n    )\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integration/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integration/conftest.py",
    "content": "import pytest\nimport os\nfrom .lib import startContainer, get_free_port, DOCKER_GW_IP\n\n\n@pytest.fixture(autouse=True, scope=\"session\")\ndef my_fixture():\n    # setup_stuff\n    print(\"session start\")\n\n    # newCache = not os.getenv(\"DDA_https_proxy\")\n    newCache = False\n\n    if newCache:\n        squid_port = get_free_port()\n        http_port = get_free_port()\n        container, stop = startContainer(\n            \"gadicc/squid-ssl-zero\",\n            ports={3128: squid_port, 3129: http_port},\n        )\n        os.environ[\"DDA_http_proxy\"] = f\"http://{DOCKER_GW_IP}:{squid_port}\"\n        os.environ[\"DDA_https_proxy\"] = os.environ[\"DDA_http_proxy\"]\n        # TODO, code in getDDA to download cert\n\n    yield\n    # teardown_stuff\n    print(\"session end\")\n    if newCache:\n        stop()\n"
  },
  {
    "path": "tests/integration/lib.py",
    "content": "import pytest\nimport docker\nimport atexit\nimport time\nimport boto3\nimport os\nimport requests\nimport socket\nimport asyncio\nimport sys\nimport subprocess\nimport selectors\nfrom threading import Thread\nfrom argparse import Namespace\n\nAWS_S3_DEFAULT_BUCKET = os.environ.get(\"AWS_S3_DEFAULT_BUCKET\", \"test\")\nDOCKER_GW_IP = \"172.17.0.1\"  # will override below if found\n\nmyContainers = list()\ndockerClient = docker.DockerClient(\n    base_url=\"unix://var/run/docker.sock\", version=\"auto\"\n)\nfor network in dockerClient.networks.list():\n    if network.attrs[\"Scope\"] == \"local\" and network.attrs[\"Driver\"] == \"bridge\":\n        DOCKER_GW_IP = network.attrs[\"IPAM\"][\"Config\"][0][\"Gateway\"]\n        break\n\n# # https://stackoverflow.com/a/53255955/1839099\n# def fire_and_forget(f):\n#     def wrapped(*args, **kwargs):\n#         return asyncio.get_event_loop().run_in_executor(None, f, *args, *kwargs)\n#     return wrapped\n#\n# @fire_and_forget\n# def log_streamer(container):\n#   for line in container.logs(stream=True):\n#     print(line.decode(), end=\"\")\n\n\ndef log_streamer(container, name=None):\n    \"\"\"\n    Streams logs to stdout/stderr.\n    Order is not guaranteed (have tried 3 different methods)\n    \"\"\"\n    # Method 1: pipe streams directly -- even this doesn't guarantee order\n    # Method 2: threads + readline\n    # Method 3: selectors + read1\n    method = 1\n\n    if method == 1:\n        kwargs = {\n            \"stdout\": sys.stdout,\n            \"stderr\": sys.stderr,\n        }\n    elif method == 2:\n        kwargs = {\n            \"stdout\": subprocess.PIPE,\n            \"stderr\": subprocess.PIPE,\n            \"bufsize\": 1,\n            \"universal_newlines\": True,\n        }\n    elif method == 3:\n        kwargs = {\n            \"stdout\": subprocess.PIPE,\n            \"stderr\": subprocess.PIPE,\n            \"bufsize\": 1,\n        }\n\n    prefix = f\"[{name or container.id[:7]}] 
\"\n    print(prefix + \"== Streaming logs (stdout/stderr order not guaranteed): ==\")\n\n    sp = subprocess.Popen([\"docker\", \"logs\", \"-f\", container.id], **kwargs)\n\n    if method == 2:\n\n        def reader(pipe):\n            while True:\n                read = pipe.readline()\n                if read == \"\" and sp.poll() is not None:\n                    break\n                print(prefix + read, end=\"\")\n                sys.stdout.flush()\n                sys.stderr.flush()\n\n        Thread(target=reader, args=[sp.stdout]).start()\n        Thread(target=reader, args=[sp.stderr]).start()\n\n    elif method == 3:\n        selector = selectors.DefaultSelector()\n        selector.register(sp.stdout, selectors.EVENT_READ)\n        selector.register(sp.stderr, selectors.EVENT_READ)\n        loop = True\n\n        while loop:\n            for key, _ in selector.select():\n                data = key.fileobj.read1().decode()\n                if not data:\n                    loop = False\n                    break\n                line = prefix + str(data).rstrip().replace(\"\\n\", \"\\n\" + prefix)\n                if key.fileobj is sp.stdout:\n                    print(line)\n                    sys.stdout.flush()\n                else:\n                    print(line, file=sys.stderr)\n                    sys.stderr.flush()\n\n\ndef get_free_port():\n    s = socket.socket()\n    s.bind((\"\", 0))\n    port = s.getsockname()[1]\n    s.close()\n    return port\n\n\ndef startContainer(image, command=None, stream_logs=False, onstop=None, **kwargs):\n    global myContainers\n\n    container = dockerClient.containers.run(\n        image,\n        command,\n        # auto_remove=True,\n        detach=True,\n        **kwargs,\n    )\n\n    if stream_logs:\n        log_streamer(container)\n\n    myContainers.append(container)\n\n    def stop():\n        print(\"stop\", container.id)\n        container.stop()\n        container.remove()\n        
myContainers.remove(container)\n        if onstop:\n            onstop()\n\n    while container.status != \"running\" and container.status != \"exited\":\n        time.sleep(1)\n        try:\n            container.reload()\n        except Exception as error:\n            print(container.logs())\n            raise error\n        print(container.status)\n\n    # if (container.status == \"exited\"):\n    #  print(container.logs())\n    #  raise Exception(\"unexpected exit\")\n\n    print(\"returned\", container)\n    return container, stop\n\n\n_minioCache = {}\n\n\ndef getMinio(id=\"disposable\"):\n    cached = _minioCache.get(id, None)\n    if cached:\n        return Namespace(**cached)\n\n    if id == \"global\":\n        endpoint_url = os.getenv(\"AWS_S3_ENDPOINT_URL\")\n        if endpoint_url:\n            print(\"Reusing existing global minio\")\n            aws_access_key_id = os.getenv(\"AWS_ACCESS_KEY_ID\")\n            aws_secret_access_key = os.getenv(\"AWS_SECRET_ACCESS_KEY\")\n            aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET\n            s3 = boto3.client(\n                \"s3\",\n                endpoint_url=endpoint_url,\n                config=boto3.session.Config(signature_version=\"s3v4\"),\n                aws_access_key_id=aws_access_key_id,\n                aws_secret_access_key=aws_secret_access_key,\n                aws_session_token=None,\n                # verify=False,\n            )\n            result = {\n                # don't link to actual container, and don't rm it at end\n                \"container\": \"global\",\n                \"stop\": lambda: print(),\n                # \"port\": port,\n                \"endpoint_url\": endpoint_url,\n                \"aws_access_key_id\": aws_access_key_id,\n                \"aws_secret_access_key\": aws_secret_access_key,\n                \"aws_s3_default_bucket\": aws_s3_default_bucket,\n                \"s3\": s3,\n            }\n            _minioCache.update({id: result})\n     
       return Namespace(**result)\n        else:\n            print(\"Creating new global minio\")\n\n    port = get_free_port()\n\n    def onstop():\n        del _minioCache[id]\n\n    container, stop = startContainer(\n        \"minio/minio\",\n        \"server /data --console-address :9001\",\n        ports={9000: port},\n        onstop=onstop,\n    )\n\n    endpoint_url = f\"http://{DOCKER_GW_IP}:{port}\"\n\n    while True:\n        time.sleep(1)\n        response = None\n        try:\n            print(endpoint_url + \"/minio/health/live\")\n            response = requests.get(endpoint_url + \"/minio/health/live\")\n        except Exception as error:\n            print(error)\n\n        if response and response.status_code == 200:\n            break\n\n    aws_access_key_id = \"minioadmin\"\n    aws_secret_access_key = \"minioadmin\"\n    aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET\n    s3 = boto3.client(\n        \"s3\",\n        endpoint_url=endpoint_url,\n        config=boto3.session.Config(signature_version=\"s3v4\"),\n        aws_access_key_id=aws_access_key_id,\n        aws_secret_access_key=aws_secret_access_key,\n        aws_session_token=None,\n        # verify=False,\n    )\n\n    s3.create_bucket(Bucket=AWS_S3_DEFAULT_BUCKET)\n\n    result = {\n        \"container\": container,\n        \"stop\": stop,\n        \"port\": port,\n        \"endpoint_url\": endpoint_url,\n        \"aws_access_key_id\": aws_access_key_id,\n        \"aws_secret_access_key\": aws_secret_access_key,\n        \"aws_s3_default_bucket\": aws_s3_default_bucket,\n        \"s3\": s3,\n    }\n    _minioCache.update({id: result})\n    return Namespace(**result)\n\n\n_ddaCache = None\n\n\ndef getDDA(\n    minio=None,\n    command=None,\n    environment={},\n    stream_logs=False,\n    wait=True,\n    root_cache=True,\n    **kwargs,\n):\n    global _ddaCache\n    if _ddaCache:\n        print(\"return _ddaCache\")\n        return Namespace(**_ddaCache)\n    else:\n        
print(\"create new _dda\")\n\n    port = get_free_port()\n\n    environment.update(\n        {\n            \"HF_AUTH_TOKEN\": os.getenv(\"HF_AUTH_TOKEN\"),\n            \"http_proxy\": os.getenv(\"DDA_http_proxy\"),\n            \"https_proxy\": os.getenv(\"DDA_https_proxy\"),\n            \"REQUESTS_CA_BUNDLE\": os.getenv(\"DDA_http_proxy\")\n            and \"/usr/local/share/ca-certificates/squid-self-signed.crt\",\n        }\n    )\n\n    if minio:\n        environment.update(\n            {\n                \"AWS_ACCESS_KEY_ID\": minio.aws_access_key_id,\n                \"AWS_SECRET_ACCESS_KEY\": minio.aws_secret_access_key,\n                \"AWS_DEFAULT_REGION\": \"\",\n                \"AWS_S3_DEFAULT_BUCKET\": minio.aws_s3_default_bucket,\n                \"AWS_S3_ENDPOINT_URL\": minio.endpoint_url,\n            }\n        )\n\n    def onstop():\n        global _ddaCache\n        _ddaCache = None\n\n    HOME = os.getenv(\"HOME\")\n\n    container, stop = startContainer(\n        \"gadicc/diffusers-api:test\",\n        command,\n        stream_logs=stream_logs,\n        ports={8000: port},\n        device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[[\"gpu\"]])],\n        environment=environment,\n        volumes=root_cache and [f\"{HOME}/root-cache:/root/.cache\"],\n        onstop=onstop,\n        **kwargs,\n    )\n\n    url = f\"http://{DOCKER_GW_IP}:{port}/\"\n\n    while wait:\n        time.sleep(1)\n        container.reload()\n        if container.status == \"exited\":\n            if not stream_logs:\n                print(\"--- EARLY EXIT ---\")\n                print(container.logs().decode())\n                print(\"--- EARLY EXIT ---\")\n            raise Exception(\"Early exit before successful healthcheck\")\n\n        response = None\n        try:\n            # print(url + \"healthcheck\")\n            response = requests.get(url + \"healthcheck\")\n        except Exception as error:\n            # print(error)\n            
continue\n\n        if response:\n            if response.status_code == 200:\n                result = response.json()\n                if result[\"state\"] == \"healthy\" and result[\"gpu\"] == True:\n                    print(\"Ready\")\n                    break\n                else:\n                    print(response)\n                    print(response.text)\n            else:\n                raise Exception(\"Unexpected status code from dda/healthcheck\")\n\n    data = {\n        \"container\": container,\n        \"stop\": stop,\n        \"minio\": minio,\n        \"port\": port,\n        \"url\": url,\n    }\n\n    _ddaCache = data\n    return Namespace(**data)\n\n\ndef cleanup():\n    print(\"cleanup\")\n    for container in myContainers:\n        print(\"Stopping\")\n        print(container)\n        container.stop()\n        print(\"removing\")\n        container.remove()\n\n\natexit.register(cleanup)\n"
  },
  {
    "path": "tests/integration/requirements.txt",
    "content": "pytest==7.2.0\ndocker==6.0.1\nboto3==1.26.44\nPillow==9.4.0\n# work around breaking changes in urllib3 2.0\n# until https://github.com/docker/docker-py/pull/3114/files lands\nurllib3<2\n"
  },
  {
    "path": "tests/integration/test_attn_procs.py",
    "content": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\nif False:\n\n    class TestAttnProcs:\n        def setup_class(self):\n            print(\"setup_class\")\n            # self.minio = minio = getMinio(\"global\")\n\n            self.dda = dda = getDDA(\n                # minio=minio\n                stream_logs=True,\n            )\n            print(dda)\n\n            self.TEST_ARGS = {\"test_url\": dda.url}\n\n        def teardown_class(self):\n            print(\"teardown_class\")\n            # self.minio.stop() - leave global up\n            self.dda.stop()\n\n        def test_lora_hf_download(self):\n            \"\"\"\n            Download user/repo from HuggingFace.\n            \"\"\"\n            # fp32 model is obviously bigger\n            result = runTest(\n                \"txt2img\",\n                self.TEST_ARGS,\n                {\n                    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                    \"MODEL_REVISION\": \"fp16\",\n                    \"MODEL_PRECISION\": \"fp16\",\n                    \"attn_procs\": \"patrickvonplaten/lora_dreambooth_dog_example\",\n                },\n                {\n                    \"num_inference_steps\": 1,\n                    \"prompt\": \"A picture of a sks dog in a bucket\",\n                    \"seed\": 1,\n                    \"cross_attention_kwargs\": {\"scale\": 0.5},\n                },\n            )\n\n            assert result[\"image_base64\"]\n\n        def test_lora_http_download_pytorch_bin(self):\n            \"\"\"\n            Download pytroch_lora_weights.bin directly.\n            \"\"\"\n            result = runTest(\n                \"txt2img\",\n                self.TEST_ARGS,\n                {\n                    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                    \"MODEL_REVISION\": \"fp16\",\n                    \"MODEL_PRECISION\": \"fp16\",\n                    \"attn_procs\": 
\"https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin\",\n                },\n                {\n                    \"num_inference_steps\": 1,\n                    \"prompt\": \"A picture of a sks dog in a bucket\",\n                    \"seed\": 1,\n                    \"cross_attention_kwargs\": {\"scale\": 0.5},\n                },\n            )\n\n            assert result[\"image_base64\"]\n\n        if False:\n            # These formats are not supported by diffusers yet :(\n            def test_lora_http_download_civitai_safetensors(self):\n                result = runTest(\n                    \"txt2img\",\n                    self.TEST_ARGS,\n                    {\n                        \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                        \"MODEL_REVISION\": \"fp16\",\n                        \"MODEL_PRECISION\": \"fp16\",\n                        \"attn_procs\": \"https://civitai.com/api/download/models/11523\",\n                        \"attn_procs_from_safetensors\": True,\n                    },\n                    {\n                        \"num_inference_steps\": 1,\n                        \"prompt\": \"A picture of a sks dog in a bucket\",\n                        \"seed\": 1,\n                    },\n                )\n\n                assert result[\"image_base64\"]\n"
  },
  {
    "path": "tests/integration/test_build_download.py",
    "content": "import sys\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_cloudcache_build_download():\n    \"\"\"\n    Download a model from cloud-cache at build time (no HuggingFace)\n    \"\"\"\n    minio = getMinio()\n    print(minio)\n    environment = {\n        \"RUNTIME_DOWNLOADS\": 0,\n        \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n        \"MODEL_PRECISION\": \"fp16\",\n        \"MODEL_REVISION\": \"fp16\",\n        \"MODEL_URL\": \"s3://\",  # <--\n    }\n    # conda = \"conda run --no-capture-output -n xformers\"\n    conda = \"\"\n    dda = getDDA(\n        minio=minio,\n        stream_logs=True,\n        environment=environment,\n        root_cache=False,\n        command=[\n            \"sh\",\n            \"-c\",\n            f\"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py\",\n        ],\n    )\n    print(dda)\n    assert dda.container.status == \"running\"\n\n    ## bucket.objects.all().delete()\n    result = runTest(\n        \"txt2img\",\n        {\"test_url\": dda.url},\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n        },\n        {\"num_inference_steps\": 1},\n    )\n\n    dda.stop()\n    minio.stop()\n    assert result[\"image_base64\"]\n    print(\"test successs\\n\\n\")\n\n\ndef test_huggingface_build_download():\n    \"\"\"\n    Download a model from HuggingFace at build time (no cloud-cache)\n    NOTE / TODO: Good starting point, but this still runs with gpu and\n    uploads if missing.\n    \"\"\"\n    environment = {\n        \"RUNTIME_DOWNLOADS\": 0,\n        \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n        \"MODEL_PRECISION\": \"fp16\",\n        \"MODEL_REVISION\": \"fp16\",\n    }\n    # conda = \"conda run --no-capture-output -n xformers\"\n    conda = \"\"\n    dda = getDDA(\n        stream_logs=True,\n        environment=environment,\n        root_cache=False,\n        command=[\n            \"sh\",\n   
         \"-c\",\n            f\"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py\",\n        ],\n    )\n    print(dda)\n    assert dda.container.status == \"running\"\n\n    ## bucket.objects.all().delete()\n    result = runTest(\n        \"txt2img\",\n        {\"test_url\": dda.url},\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            # \"MODEL_ID\": \"hf-internal-testing/tiny-stable-diffusion-pipe\",\n            \"MODEL_PRECISION\": \"fp16\",\n            \"MODEL_REVISION\": \"fp16\",\n            \"MODEL_URL\": \"\",  # <-- no model_url, i.e. no cloud cache\n        },\n        {\"num_inference_steps\": 1},\n    )\n    dda.stop()\n\n    assert result[\"image_base64\"]\n    print(\"test successs\\n\\n\")\n\n\ndef test_checkpoint_url_build_download():\n    \"\"\"\n    Download and convert a .ckpt at build time.  No cloud-cache.\n    \"\"\"\n    environment = {\n        \"RUNTIME_DOWNLOADS\": 0,\n        \"MODEL_ID\": \"hakurei/waifu-diffusion-v1-3\",\n        \"MODEL_PRECISION\": \"fp16\",\n        \"MODEL_REVISION\": \"fp16\",\n        \"CHECKPOINT_URL\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt\",\n    }\n    # conda = \"conda run --no-capture-output -n xformers\"\n    conda = \"\"\n    dda = getDDA(\n        stream_logs=True,\n        environment=environment,\n        root_cache=False,\n        command=[\n            \"sh\",\n            \"-c\",\n            f\"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py\",\n        ],\n    )\n    print(dda)\n    assert dda.container.status == \"running\"\n\n    ## bucket.objects.all().delete()\n    result = runTest(\n        \"txt2img\",\n        {\"test_url\": dda.url},\n        {\n            \"MODEL_ID\": \"hakurei/waifu-diffusion-v1-3\",\n            \"MODEL_PRECISION\": \"fp16\",\n            \"MODEL_URL\": \"\",  # <-- no model_url, i.e. 
no cloud cache\n        },\n        {\"num_inference_steps\": 1},\n    )\n    dda.stop()\n\n    assert result[\"image_base64\"]\n    print(\"test successs\\n\\n\")\n"
  },
  {
    "path": "tests/integration/test_cloud_cache.py",
    "content": "import sys\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_cloud_cache_create_and_upload():\n    \"\"\"\n    Check if model exists in cloud cache bucket download otherwise, save\n    with safetensors, and upload model.tar.zst to bucket\n    \"\"\"\n    minio = getMinio()\n    print(minio)\n    dda = getDDA(minio=minio, stream_logs=True, root_cache=False)\n    print(dda)\n\n    ## bucket.objects.all().delete()\n    result = runTest(\n        \"txt2img\",\n        {\"test_url\": dda.url},\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            # \"MODEL_ID\": \"hf-internal-testing/tiny-stable-diffusion-pipe\",\n            \"MODEL_PRECISION\": \"fp16\",\n            \"MODEL_REVISION\": \"fp16\",\n            \"MODEL_URL\": \"s3://\",\n        },\n        {\"num_inference_steps\": 1},\n    )\n\n    dda.stop()\n    minio.stop()\n    timings = result[\"$timings\"]\n    assert timings[\"download\"] > 0\n    assert timings[\"upload\"] > 0\n"
  },
  {
    "path": "tests/integration/test_dreambooth.py",
    "content": "import os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\nHF_USERNAME = os.getenv(\"HF_USERNAME\", \"gadicc\")\n\n\nclass TestDreamBoothS3:\n    \"\"\"\n    Train/Infer via S3 model save.\n    \"\"\"\n\n    def setup_class(self):\n        print(\"setup_class\")\n        self.minio = getMinio(\"global\")\n\n    def teardown_class(self):\n        print(\"teardown_class\")\n        # self.minio.stop() # leave global up.\n\n    def test_training_s3(self):\n        dda = getDDA(\n            minio=self.minio,\n            stream_logs=True,\n        )\n        print(dda)\n\n        result = runTest(\n            \"dreambooth\",\n            {\"test_url\": dda.url},\n            {\n                \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n                \"MODEL_REVISION\": \"\",\n                \"MODEL_PRECISION\": \"\",\n                \"MODEL_URL\": \"s3://\",\n                \"train\": \"dreambooth\",\n                \"dest_url\": f\"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst\",\n            },\n            {\"max_train_steps\": 1},\n        )\n\n        dda.stop()\n        timings = result[\"$timings\"]\n        assert timings[\"training\"] > 0\n        assert timings[\"upload\"] > 0\n\n    # dependent on above, TODO, mark as such.\n    def test_s3_download_and_inference(self):\n        dda = getDDA(\n            minio=self.minio,\n            stream_logs=True,\n            root_cache=False,\n        )\n        print(dda)\n\n        result = runTest(\n            \"txt2img\",\n            {\"test_url\": dda.url},\n            {\n                \"MODEL_ID\": \"model\",\n                \"MODEL_PRECISION\": \"fp16\",\n                \"MODEL_URL\": f\"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst\",\n            },\n            {\"num_inference_steps\": 1},\n        )\n\n        dda.stop()\n        assert result[\"image_base64\"]\n\n\nif os.getenv(\"TEST_DREAMBOOTH_HF\", None):\n\n    class 
TestDreamBoothHF:\n        def test_training_hf(self):\n            dda = getDDA(\n                stream_logs=True,\n            )\n            print(dda)\n\n            result = runTest(\n                \"dreambooth\",\n                {\"test_url\": dda.url},\n                {\n                    \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n                    \"MODEL_REVISION\": \"\",\n                    \"MODEL_PRECISION\": \"\",\n                    \"MODEL_URL\": \"s3://\",\n                    \"train\": \"dreambooth\",\n                },\n                {\n                    \"hub_model_id\": f\"{HF_USERNAME}/dreambooth_test\",\n                    \"push_to_hub\": True,\n                    \"max_train_steps\": 1,\n                },\n            )\n\n            dda.stop()\n            timings = result[\"$timings\"]\n            assert timings[\"training\"] > 0\n            assert timings[\"upload\"] > 0\n\n        # dependent on above, TODO, mark as such.\n        def test_hf_download_and_inference(self):\n            dda = getDDA(\n                stream_logs=True,\n                root_cache=False,\n            )\n            print(dda)\n\n            result = runTest(\n                \"txt2img\",\n                {\"test_url\": dda.url},\n                {\n                    \"MODEL_ID\": f\"{HF_USERNAME}/dreambooth_test\",\n                    \"MODEL_PRECISION\": \"fp16\",\n                },\n                {\"num_inference_steps\": 1},\n            )\n\n            dda.stop()\n            assert result[\"image_base64\"]\n\nelse:\n\n    print(\n        \"Skipping dreambooth HuggingFace upload/download tests by default\\n\"\n        \"as they can be flaky.  To run, set env var TEST_DREAMBOOTH_HF=1\"\n    )\n"
  },
  {
    "path": "tests/integration/test_general.py",
    "content": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\nclass TestGeneralClass:\n    \"\"\"\n    Typical usage tests, that assume model is already available locally.\n    txt2img, img2img, inpaint.\n    \"\"\"\n\n    CALL_ARGS = {\n        \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n        \"MODEL_PRECISION\": \"fp16\",\n        \"MODEL_REVISION\": \"fp16\",\n        \"MODEL_URL\": \"s3://\",\n    }\n\n    MODEL_ARGS = {\"num_inference_steps\": 2}\n\n    def setup_class(self):\n        print(\"setup_class\")\n        self.minio = minio = getMinio(\"global\")\n\n        self.dda = dda = getDDA(\n            minio=minio\n            # stream_logs=True,\n        )\n        print(dda)\n\n        self.TEST_ARGS = {\"test_url\": dda.url}\n\n    def teardown_class(self):\n        print(\"teardown_class\")\n        # self.minio.stop() - leave global up\n        self.dda.stop()\n\n    def test_txt2img(self):\n        result = runTest(\"txt2img\", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)\n        assert result[\"image_base64\"]\n\n    def test_img2img(self):\n        result = runTest(\"img2img\", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)\n        assert result[\"image_base64\"]\n\n    # def test_inpaint(self):\n    #     \"\"\"\n    #     This is actually calling inpaint with SDv2.1, not the inpainting model,\n    #     so I guess we're testing inpaint-legacy.\n    #     \"\"\"\n    #     result = runTest(\"inpaint\", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)\n    #     assert result[\"image_base64\"]\n"
  },
  {
    "path": "tests/integration/test_loras.py",
    "content": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\nclass TestLoRAs:\n    def setup_class(self):\n        print(\"setup_class\")\n        # self.minio = minio = getMinio(\"global\")\n\n        self.dda = dda = getDDA(\n            # minio=minio\n            stream_logs=True,\n        )\n        print(dda)\n\n        self.TEST_ARGS = {\"test_url\": dda.url}\n\n    def teardown_class(self):\n        print(\"teardown_class\")\n        # self.minio.stop() - leave global up\n        self.dda.stop()\n\n    if False:\n\n        def test_lora_hf_download(self):\n            \"\"\"\n            Download user/repo from HuggingFace.\n            \"\"\"\n            # fp32 model is obviously bigger\n            result = runTest(\n                \"txt2img\",\n                self.TEST_ARGS,\n                {\n                    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                    \"MODEL_REVISION\": \"fp16\",\n                    \"MODEL_PRECISION\": \"fp16\",\n                    \"attn_procs\": \"patrickvonplaten/lora_dreambooth_dog_example\",\n                },\n                {\n                    \"num_inference_steps\": 1,\n                    \"prompt\": \"A picture of a sks dog in a bucket\",\n                    \"seed\": 1,\n                    \"cross_attention_kwargs\": {\"scale\": 0.5},\n                },\n            )\n\n            assert result[\"image_base64\"]\n\n    if False:\n\n        def test_lora_http_download_pytorch_bin(self):\n            \"\"\"\n            Download pytroch_lora_weights.bin directly.\n            \"\"\"\n            result = runTest(\n                \"txt2img\",\n                self.TEST_ARGS,\n                {\n                    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                    \"MODEL_REVISION\": \"fp16\",\n                    \"MODEL_PRECISION\": \"fp16\",\n                    \"attn_procs\": 
\"https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin\",\n                },\n                {\n                    \"num_inference_steps\": 1,\n                    \"prompt\": \"A picture of a sks dog in a bucket\",\n                    \"seed\": 1,\n                    \"cross_attention_kwargs\": {\"scale\": 0.5},\n                },\n            )\n\n            assert result[\"image_base64\"]\n\n    # These formats are not supported by diffusers yet :(\n    def test_lora_http_download_civitai_safetensors(self):\n        quickTest = True\n\n        callInputs = {\n            \"MODEL_ID\": \"NED-v1-22\",\n            # https://civitai.com/models/10028/neverending-dream-ned?modelVersionId=64094\n            \"CHECKPOINT_URL\": \"https://civitai.com/api/download/models/64094#fname=neverendingDreamNED_v122BakedVae.safetensors\",\n            \"MODEL_PRECISION\": \"fp16\",\n            # https://civitai.com/models/5373/makima-chainsaw-man-lora\n            \"lora_weights\": \"https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors\",\n            \"safety_checker\": False,\n            \"PIPELINE\": \"lpw_stable_diffusion\",\n        }\n        modelInputs = {\n            # https://civitai.com/images/709482\n            \"num_inference_steps\": 30,\n            \"prompt\": \"masterpiece, (photorealistic:1.4), best quality, beautiful lighting, (ulzzang-6500:0.5), makima \\(chainsaw man\\), (red hair)+(long braided hair)+(bangs), yellow eyes, golden eyes, ((ringed eyes)), (white shirt), (necktie), RAW photo, 8k uhd, film grain\",\n            \"negative_prompt\": \"(painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), bad anatomy\",\n            \"width\": 864,\n            \"height\": 
1304,\n            \"seed\": 2281759351,\n            \"guidance_scale\": 9,\n        }\n\n        if quickTest:\n            callInputs.update(\n                {\n                    # i.e. use a model we already have\n                    \"MODEL_ID\": \"runwayml/stable-diffusion-v1-5\",\n                    \"MODEL_REVISION\": \"fp16\",\n                    \"CHECKPOINT_URL\": None,\n                }\n            )\n            modelInputs.update(\n                {\n                    \"num_inference_steps\": 1,\n                    \"width\": 512,\n                    \"height\": 512,\n                }\n            )\n        result = runTest(\"txt2img\", self.TEST_ARGS, callInputs, modelInputs)\n\n        assert result[\"image_base64\"]\n"
  },
  {
    "path": "tests/integration/test_memory.py",
    "content": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_memory():\n    \"\"\"\n    Make sure when switching models we release VRAM afterwards.\n    \"\"\"\n    minio = getMinio(\"global\")\n    dda = getDDA(\n        minio=minio,\n        stream_logs=True,\n    )\n    print(dda)\n\n    TEST_ARGS = {\"test_url\": dda.url}\n    MODEL_ARGS = {\"num_inference_steps\": 1}\n\n    mem_usage = list()\n\n    # fp32 model is obviously bigger\n    result = runTest(\n        \"txt2img\",\n        TEST_ARGS,\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            \"MODEL_REVISION\": \"\",  # <--\n            \"MODEL_PRECISION\": \"\",  # <--\n            \"MODEL_URL\": \"s3://\",\n        },\n        MODEL_ARGS,\n    )\n    mem_usage.append(result[\"$mem_usage\"])\n\n    # fp32 model is obviously smaller\n    result = runTest(\n        \"txt2img\",\n        TEST_ARGS,\n        {\n            \"MODEL_ID\": \"stabilityai/stable-diffusion-2-1-base\",\n            \"MODEL_REVISION\": \"fp16\",  # <--\n            \"MODEL_PRECISION\": \"fp16\",  # <--\n            \"MODEL_URL\": \"s3://\",\n        },\n        MODEL_ARGS,\n    )\n    mem_usage.append(result[\"$mem_usage\"])\n\n    print({\"mem_usage\": mem_usage})\n    assert mem_usage[1] < mem_usage[0]\n\n    dda.stop()\n"
  },
  {
    "path": "touch",
    "content": ""
  },
  {
    "path": "update.sh",
    "content": "#!/bin/sh\n\nrsync -avzPe \"ssh -p $1\" api/ $2:/api/\n"
  }
]