Showing preview only (280K chars total). Download the full file or copy to clipboard to get everything.
Repository: kiri-art/docker-diffusers-api
Branch: dev
Commit: 5521b2e6d63e
Files: 73
Total size: 261.2 KB
Directory structure:
gitextract_akizhtm1/
├── .circleci/
│ └── config.yml
├── .devcontainer/
│ ├── devcontainer.json
│ └── local.example.env
├── .gitignore
├── .vscode/
│ ├── settings.json
│ └── tasks.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── __init__.py
├── api/
│ ├── app.py
│ ├── convert_to_diffusers.py
│ ├── device.py
│ ├── download.py
│ ├── download_checkpoint.py
│ ├── extras/
│ │ ├── __init__.py
│ │ └── upsample/
│ │ ├── __init__.py
│ │ ├── models.py
│ │ └── upsample.py
│ ├── getPipeline.py
│ ├── getScheduler.py
│ ├── lib/
│ │ ├── __init__.py
│ │ ├── prompts.py
│ │ ├── textual_inversions.py
│ │ ├── textual_inversions_test.py
│ │ └── vars.py
│ ├── loadModel.py
│ ├── precision.py
│ ├── send.py
│ ├── server.py
│ ├── status.py
│ ├── tests.py
│ ├── train_dreambooth.py
│ └── utils/
│ ├── __init__.py
│ └── storage/
│ ├── BaseStorage.py
│ ├── BaseStorage_test.py
│ ├── HTTPStorage.py
│ ├── S3Storage.py
│ ├── S3Storage_test.py
│ ├── __init__.py
│ └── __init__test.py
├── build
├── docs/
│ ├── internal_safetensor_cache_flow.md
│ └── storage.md
├── install.sh
├── package.json
├── prime.sh
├── release.config.js
├── requirements.txt
├── run.sh
├── run_integration_tests_on_lambda.sh
├── scripts/
│ ├── devContainerPostCreate.sh
│ ├── devContainerServer.sh
│ ├── patchmatch-setup.sh
│ ├── permutations.yaml
│ └── permute.sh
├── test.py
├── tests/
│ ├── __init__.py
│ └── integration/
│ ├── __init__.py
│ ├── conftest.py
│ ├── lib.py
│ ├── requirements.txt
│ ├── test_attn_procs.py
│ ├── test_build_download.py
│ ├── test_cloud_cache.py
│ ├── test_dreambooth.py
│ ├── test_general.py
│ ├── test_loras.py
│ └── test_memory.py
├── touch
└── update.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .circleci/config.yml
================================================
version: 2.1
jobs:
  build:
    docker:
      - image: cimg/python:3.9-node
    resource_class: medium
    # would have been nice, but not for $2,000/month!
    # machine:
    #   image: ubuntu-2004-cuda-11.4:202110-01
    # resource_class: gpu.nvidia.small
    steps:
      - checkout
      - setup_remote_docker:
          docker_layer_caching: true
      - run: docker build -t gadicc/diffusers-api .
      # unit tests
      # - run: docker run gadicc/diffusers-api conda run --no-capture -n xformers pytest --cov=. --cov-report=xml --ignore=diffusers
      - run: docker run gadicc/diffusers-api pytest --cov=. --cov-report=xml --ignore=diffusers --ignore=Real-ESRGAN
      # Credentials are quoted so values containing spaces or glob characters
      # reach docker intact; the password is passed via stdin, not argv.
      - run: echo "$DOCKER_PASSWORD" | docker login --username "$DOCKER_USERNAME" --password-stdin
      # push for non-semver branches (e.g. dev, feature branches)
      # (older copy of the "Push to hub ..." step further below, which now
      # runs after the integration tests)
      # - run:
      #     name: Push to hub on branches not handled by semantic-release
      #     command: |
      #       SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\/\/.*//' | jq .branches[])
      #
      #       if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then
      #         echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'"
      #         echo "Semantic-release will handle the publishing"
      #       else
      #         echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH"
      #         docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH .
      #         docker push gadicc/diffusers-api:$CIRCLE_BRANCH
      #         echo "Skipping integration tests"
      #         circleci-agent step halt
      #       fi
      # needed for later "apt-get install" steps
      - run: sudo apt-get update
      ## TODO. The below was a great first step, but in future, let's build
      # the container on the host, run docker remotely on lambda, and
      # publish the same built image if tests pass.
      # TODO, only run on main channel for releases (with sem-rel too)
      # integration tests
      # apt-get (not apt): apt's CLI is not meant to be stable for scripted use.
      - run: sudo apt-get install -yqq rsync pv
      - run: ./run_integration_tests_on_lambda.sh
      - run:
          name: Push to hub on branches not handled by semantic-release
          command: |
            SEMVER_BRANCHES=$(sed 's/module.exports = //' release.config.js | sed 's/\/\/.*//' | jq '.branches[]')
            if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then
              echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'"
              echo "Semantic-release will handle the publishing"
            else
              echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH"
              # Re-tag the image built (and unit-tested) earlier in this job
              # instead of rebuilding it, so the pushed image is the tested one.
              docker tag gadicc/diffusers-api "gadicc/diffusers-api:$CIRCLE_BRANCH"
              docker push "gadicc/diffusers-api:$CIRCLE_BRANCH"
              # echo "Skipping integration tests"
              # circleci-agent step halt
            fi
      # deploy the image
      # - run: docker push company/app:$CIRCLE_BRANCH
      # https://github.com/semantic-release-plus/semantic-release-plus/tree/master/packages/plugins/docker
      - run:
          name: release
          command: |
            # -y: CI has no TTY to answer apt-get's confirmation prompt, so
            # without it the install aborts whenever anything needs installing.
            sudo apt-get install -yqq yarn
            yarn install
            yarn run semantic-release-plus
================================================
FILE: .devcontainer/devcontainer.json
================================================
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
{
"name": "Existing Dockerfile",
"build": {
// Sets the run context to one level up instead of the .devcontainer folder.
"context": "..",
// Update the 'dockerfile' property if you aren't using the standard 'Dockerfile' filename.
"dockerfile": "../Dockerfile"
},
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {
"ghcr.io/devcontainers/features/python:1": {
// "version": "3.10"
}
},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
"forwardPorts": [8000],
// Run setup commands after the container is created.
"postCreateCommand": "scripts/devContainerPostCreate.sh",
"customizations": {
"vscode": {
"extensions": [
"ryanluker.vscode-coverage-gutters",
"fsevenm.run-it-on",
"ms-python.black-formatter",
],
"settings": {
"python.pythonPath": "/opt/conda/bin/python"
}
}
},
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "devcontainer"
"mounts": [
"source=${localEnv:HOME}/root-cache,target=/root/.cache,type=bind,consistency=cached"
],
"runArgs": [
"--gpus",
"all",
"--env-file",
".devcontainer/local.env"
]
}
================================================
FILE: .devcontainer/local.example.env
================================================
# Useful environment variables:
# AWS or S3-compatible storage credentials and buckets
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_DEFAULT_REGION=
AWS_S3_DEFAULT_BUCKET=
# Only fill this in if your (non-AWS) provider has told you what to put here
AWS_S3_ENDPOINT_URL=
# To use a proxy, e.g.
# https://github.com/kiri-art/docker-diffusers-api/blob/dev/CONTRIBUTING.md#local-https-caching-proxy
# DDA_http_proxy=http://172.17.0.1:3128
# DDA_https_proxy=http://172.17.0.1:3128
# HuggingFace credentials
HF_AUTH_TOKEN=
HF_USERNAME=
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
permutations
tests/output
node_modules
.devcontainer/local.env
================================================
FILE: .vscode/settings.json
================================================
{
"python.testing.pytestArgs": [
"--cov=.",
"--cov-report=xml",
"--ignore=test.py",
"--ignore=tests/integration",
"--ignore=diffusers",
// "unit_tests.py"
// "."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
// "python.defaultInterpreterPath": "/opt/conda/envs/xformers/bin/python",
"python.defaultInterpreterPath": "/opt/conda/bin/python",
"runItOn": {
"commands": [
{
"match": "\\.py$",
"isAsync": true,
"isShellCommand": false,
"cmd": "testing.runAll"
},
],
},
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}
================================================
FILE: .vscode/tasks.json
================================================
{
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
"version": "2.0.0",
"tasks": [
{
"label": "Watching Server",
"type": "shell",
"command": "scripts/devContainerServer.sh"
}
]
}
================================================
FILE: CHANGELOG.md
================================================
# [1.7.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.6.0...v1.7.0) (2023-09-04)
### Bug Fixes
* **addons:** async TI download status, LoRA improvements ([de8cfdc](https://github.com/kiri-art/docker-diffusers-api/commit/de8cfdc63d7ae46bed90862fe3bffe65534d3e55))
* **circleci:** pytest --ignore=Real-ESRGAN ([d7038b5](https://github.com/kiri-art/docker-diffusers-api/commit/d7038b5aa54c8b3dab2149ea773e007b9c0202ce))
* **circleci:** remove conda from pytest call ([2f29af2](https://github.com/kiri-art/docker-diffusers-api/commit/2f29af2c012ef38ed2e2bc0ec116b59b8c429e57))
* **diffusers:** bump to aae2726 (jul30) post v0.19.2 + fixes ([6c0a10a](https://github.com/kiri-art/docker-diffusers-api/commit/6c0a10a743abb7cd12cce9bf1cc6a598c6804e92))
* **Dockerfile:** -yqq for apt-get, apt-utils, extra deps ([bf470da](https://github.com/kiri-art/docker-diffusers-api/commit/bf470dabb9b3c6d7f16d11126ffef0f4ee4806f5))
* **Dockerfile:** TZ tzdata fix ([9c5d911](https://github.com/kiri-art/docker-diffusers-api/commit/9c5d911aafedc1a2dab94a5c1c1c25aa4bc0ce7a))
* **misc:** fix failing tests, pipeline init in rare circumstances ([9338648](https://github.com/kiri-art/docker-diffusers-api/commit/933864893a35dfb9fa093b988a5b159af4e0a9ca))
* **prime/update:** commit these useful utility scripts ([7b167c0](https://github.com/kiri-art/docker-diffusers-api/commit/7b167c0508e7a476d8c6719e056d6bdfa255e2d8))
* **upsample:** return $meta for kiri ([b9dd6b7](https://github.com/kiri-art/docker-diffusers-api/commit/b9dd6b780005ad17090220fba99f0329b98f9c09))
* **x_attn_kwargs:** only pass to pipeline if set ([3f1f980](https://github.com/kiri-art/docker-diffusers-api/commit/3f1f980930edb9bad28c6c026d31ca084887b442))
### Features
* **checkpoints:** use correct pipeline for "inpaint" in path ([16dd383](https://github.com/kiri-art/docker-diffusers-api/commit/16dd38327d291de29da012026a2ffcede0681526))
* **loras:** ability to specify #?scale=0.1 -> cross_attn_kwargs ([747fc0d](https://github.com/kiri-art/docker-diffusers-api/commit/747fc0ddec1db91617fb01f4d7ef9b8291de221d))
* **pytorch2:** bump deps, drop conda/xformers ([a3d8078](https://github.com/kiri-art/docker-diffusers-api/commit/a3d807896e2b0d831580b78be556fcc69be08353))
* **sdxl,compel:** Support. AutoPipeline default, safety_check fix ([993be12](https://github.com/kiri-art/docker-diffusers-api/commit/993be124c2e5b0f04b1cf25ca285e3a6573ce19a))
* **sdxl:** fix sd_xl, loras; ability to init load specific pipeline ([7e3af77](https://github.com/kiri-art/docker-diffusers-api/commit/7e3af77167b58481d3c974ae33c3991ef976fc28))
* **textualInversion:** very early support ([2babd53](https://github.com/kiri-art/docker-diffusers-api/commit/2babd539a6fcb396bb1f323fe9c50cdccb91cf96))
* **upsample:** initial RealESRGAN support for runtime downloads ([8929508](https://github.com/kiri-art/docker-diffusers-api/commit/8929508adea8cd0e50ccf79aaea2a13354f37fa8))
# [1.6.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.5.0...v1.6.0) (2023-07-12)
### Bug Fixes
* **BaseStorage:** mv misplaced .query from BaseArchive to BaseStorage ([0c7a757](https://github.com/kiri-art/docker-diffusers-api/commit/0c7a757634cb62bacb3efda7f9a6e4b85bb3cb4e))
* **conversion:** recognize "safetensor" anywhere in filename ([1ceab7d](https://github.com/kiri-art/docker-diffusers-api/commit/1ceab7dfb1d0d507b3b61f777453d81caf5190c2))
* **deps:** bump diffusers to b9feed8, lock bitsandbytes==0.39.1 ([be1c322](https://github.com/kiri-art/docker-diffusers-api/commit/be1c32218cd0e312077de2b7a10b41f2f5be07e0))
* **deps:** diffusers to 0.17.0 + latest commits, other packages ([a6e9db0](https://github.com/kiri-art/docker-diffusers-api/commit/a6e9db09382d972da3c6c08786ff92986e7585b7))
* **pipelines:** pass revision/precision for community pipelines too ([20311cf](https://github.com/kiri-art/docker-diffusers-api/commit/20311cf51babf16609af1495585a4e9fca1f05e4))
* **safety_checker:** drop DummySafetyChecker and just use None ([e4fbf22](https://github.com/kiri-art/docker-diffusers-api/commit/e4fbf225e0f09c8591f2537e3061977fad6386ed))
### Features
* **checkpoints:** support #fname query in HTTPStorage ([0cb839d](https://github.com/kiri-art/docker-diffusers-api/commit/0cb839db75f86c07d568b4a379bedba971340eb0))
* **dreambooth:** update / merge in all upstream changes to date ([a40129a](https://github.com/kiri-art/docker-diffusers-api/commit/a40129a2b2f47282cc463d1249985d4b07ec16c9))
* **loras:** use load_lora_weights (works with A1111 files too) ([7a64846](https://github.com/kiri-art/docker-diffusers-api/commit/7a6484642a11fc3f3de780d4627de2dd48607d89))
* **storage:** allow #a=1&b=2 params; HTTP can use #fname=XXX ([4fe13ef](https://github.com/kiri-art/docker-diffusers-api/commit/4fe13ef7fbd4948e5f665e3d38a57430def561b8))
# [1.5.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.4.0...v1.5.0) (2023-05-24)
### Bug Fixes
* **app:** async fixes for download, train_dreambooth ([0dcbd16](https://github.com/kiri-art/docker-diffusers-api/commit/0dcbd16c1a85a9f3fb867a28d66b00f0eccaba80))
* **app:** diffusers callback cannot be async; use asyncio.run() ([7854649](https://github.com/kiri-art/docker-diffusers-api/commit/7854649011d370497690618fe3ea0e8ce2c79bc6))
* **app:** up sanic RESPONSE_TIMEOUT from 1m to 1hr ([8e2003a](https://github.com/kiri-art/docker-diffusers-api/commit/8e2003afad8af93d4e1442138d6b7673e32af971))
* **attn_procs:** apply workaround only for storage not hf repos ([b98710f](https://github.com/kiri-art/docker-diffusers-api/commit/b98710f144265df3d77a90bfb39d2dd30fbd8c96))
* **attn_procs:** load non-safetensors attn_procs ourselves ([072e7a3](https://github.com/kiri-art/docker-diffusers-api/commit/072e7a38f13d66b3e069427c318e16dcd5b6324d)), closes [huggingface/diffusers#2448 (comment)](https://github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119)
* **deps:** pin websockets<11.0 for sanic ([33ae2f4](https://github.com/kiri-art/docker-diffusers-api/commit/33ae2f4c905c5e92aa9ff6cc2f61a3adb81b1b59))
* **inference:** return $error NO_MODEL_ID vs later crash on None ([46ea977](https://github.com/kiri-art/docker-diffusers-api/commit/46ea977cea6e469059931d722df5a38a3f931d77))
* **storage:** actually, always set self.status (default None) ([c309ca9](https://github.com/kiri-art/docker-diffusers-api/commit/c309ca92fd1038f89dae186e35cc732e5822c8c2))
* **storage:** don't set self.status to None ([9b88b80](https://github.com/kiri-art/docker-diffusers-api/commit/9b88b8089c4063e63aab547ce945ebb1a94f2fd7))
* **storage:** extract with dir= must not mutate dir (download, logs) ([b1f8f87](https://github.com/kiri-art/docker-diffusers-api/commit/b1f8f87756f61ae0aa61c3785911ab043f911d98))
* **tests:** pin urllib3 to < 2, avoids break in docker package ([ccf8231](https://github.com/kiri-art/docker-diffusers-api/commit/ccf823139ac0f379e2f27d8dd5921f5343f20f8a))
### Features
* **app:** run pipeline via asyncio.to_thread ([e87f7e7](https://github.com/kiri-art/docker-diffusers-api/commit/e87f7e772fa1f5f22957600572be60b150999095))
* **attn_procs:** from_safetensors override, save .safetensors fname ([5fb6487](https://github.com/kiri-art/docker-diffusers-api/commit/5fb6487579d8b809c52f9451c68bcfcafecca0f0))
* **cors:** add sanic-ext and set default cors-origin to "*" ([eb2a385](https://github.com/kiri-art/docker-diffusers-api/commit/eb2a385684a309557b637d7c03f2e8cda00137b0))
* **diffusers:** bump to 0.15.0 + 2 weeks with lpw fix (9965cb5) ([77e9078](https://github.com/kiri-art/docker-diffusers-api/commit/77e907892b5b6b9b27aa75f5ec5732a81ba784d6))
* **diffusers:** bump to latest diffusers, 0.14 + patches (see note) ([48a99a5](https://github.com/kiri-art/docker-diffusers-api/commit/48a99a532503bf9f8932f64ddf20d7b81aab765b))
* **download:** async, status; download.py: use download_and_extract ([bb7434a](https://github.com/kiri-art/docker-diffusers-api/commit/bb7434a4e39d02dce5ecbf602fe6e41511481c12))
* **HTTPStorage:** store filename from content-disposition ([2066c44](https://github.com/kiri-art/docker-diffusers-api/commit/2066c446ba058209d1f594a46a8af0188e6e82fa))
* **loadModel:** send loadModel status ([db75740](https://github.com/kiri-art/docker-diffusers-api/commit/db75740177688e25bba4066d099a2c034dd3eb93))
* **status:** initial status work ([d1cd39e](https://github.com/kiri-art/docker-diffusers-api/commit/d1cd39ea93e4c967be91ed59b8b05a6ce9f117da))
* **storage:** support misc tar compression; progress ([a8c8337](https://github.com/kiri-art/docker-diffusers-api/commit/a8c8337da4b750f92f9712397293da20974aa385))
* **stream_events:** stream send()'s to client too ([08daf4f](https://github.com/kiri-art/docker-diffusers-api/commit/08daf4fdca1f3ad23965e9bf14a3b66fc57279fd))
# [1.4.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.3.0...v1.4.0) (2023-02-28)
### Bug Fixes
* **checkpoints:** new conversion pipeline + convert w/o MODEL_URL ([cd7f54d](https://github.com/kiri-art/docker-diffusers-api/commit/cd7f54db370462f6c3e7ecb37df791388a9ccd34))
* **diffusers:** bump to latest commit (includes v0.13.1) ([400e3d7](https://github.com/kiri-art/docker-diffusers-api/commit/400e3d7b0897e966ba3c1cc04194aedde8746edf))
* **diffusers:** bump to recent commit, includes misc LoRA fixes ([7249c30](https://github.com/kiri-art/docker-diffusers-api/commit/7249c307a9c2892a061398e75cd70965329c3ac6))
* **loadModel:** pass revision arg too ([cd5f995](https://github.com/kiri-art/docker-diffusers-api/commit/cd5f995dad9123aa4ea066ad4b9d369ef01df06b))
### Features
* **attn_procs:** initial URL work (see notes) ([6348836](https://github.com/kiri-art/docker-diffusers-api/commit/6348836622da4a17fa0e423ca9b92ebb489b4793))
* **callback:** if modelInput.callback_steps, send() current step ([2279de1](https://github.com/kiri-art/docker-diffusers-api/commit/2279de103d70614fbdee620024941dd1db81c436))
* **gpu:** auto-detect GPU (CUDA/MPS/cpu), remove hard-coded ([#20](https://github.com/kiri-art/docker-diffusers-api/issues/20)) ([682a342](https://github.com/kiri-art/docker-diffusers-api/commit/682a34221f5b586fd0d8e9c0789201cb238cf225))
* **lora:** callInput `attn_procs` to load LoRA's for inference ([cb54291](https://github.com/kiri-art/docker-diffusers-api/commit/cb542910fd234af0a02a862934bf5c090384500d))
* **send:** set / override SEND_URL, SIGN_KEY via callInputs ([74b4c53](https://github.com/kiri-art/docker-diffusers-api/commit/74b4c53bd49691df087364959123cfd48e04ac59))
# [1.3.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.2...v1.3.0) (2023-01-26)
### Bug Fixes
* **diffusers:** bump to v0.12.0 ([635d9d9](https://github.com/kiri-art/docker-diffusers-api/commit/635d9d97a010c49ef7875fcb4b43b668848ced0b))
* **diffusers:** update to latest commit ([87632aa](https://github.com/kiri-art/docker-diffusers-api/commit/87632aa2c32faddfeb049fe969884b568066edd3))
* **dreambooth:** bump diffusers, fixes fp16 mixed precision training ([0f5d5ff](https://github.com/kiri-art/docker-diffusers-api/commit/0f5d5ff2bf5b73260b9d60521389f0938f205219))
* **dreambooth:** merge commits to v0.12.0 (NB: mixed-precision issue) ([88f04f8](https://github.com/kiri-art/docker-diffusers-api/commit/88f04f870814aa9baf2a7c09513dcc796070b814))
* **pipelines:** fix clearPipelines() backport from cloud-cache ([9577f93](https://github.com/kiri-art/docker-diffusers-api/commit/9577f9344f0060edc185e32eadeb57e83551aa7f))
* **requirements:** bump transformers,accelerate,safetensors & others ([aebcf65](https://github.com/kiri-art/docker-diffusers-api/commit/aebcf6562808a817e6ee29e88f178f22f54c861b))
* **re:** use raw strings r"" for regexps ([41310c2](https://github.com/kiri-art/docker-diffusers-api/commit/41310c26bbc19069db492781313b162f0fc4d7d9))
* **tests/lambda:** export HF_AUTH_TOKEN ([9f11e7b](https://github.com/kiri-art/docker-diffusers-api/commit/9f11e7b2f0d2a377a44b22d446274677bd025813))
* **test:** shallow copy to avoid mutating base test inputs ([8c41167](https://github.com/kiri-art/docker-diffusers-api/commit/8c41167461308b14066be1472fd8957dc6cdd658))
### Features
* **downloads:** RUNTIME_DOWNLOAD from HF when no MODEL_URL given ([73784a1](https://github.com/kiri-art/docker-diffusers-api/commit/73784a1844ef2b14c628eb399bec0e52661df35c))
## [1.2.2](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.1...v1.2.2) (2023-01-09)
### Bug Fixes
* **dreambooth:** runtime_dls path fix; integration tests ([ce3827f](https://github.com/kiri-art/docker-diffusers-api/commit/ce3827f6aabd5158c39c99ffae0358d832de2e39))
* **loadModel:** revision = None if revision == "" else revision ([1773631](https://github.com/kiri-art/docker-diffusers-api/commit/1773631e292e28fae20b0a6c93406378aed85d47))
## [1.2.1](https://github.com/kiri-art/docker-diffusers-api/compare/v1.2.0...v1.2.1) (2023-01-05)
### Bug Fixes
* **build-download:** support regular HF download not just cloud cache ([52edf6b](https://github.com/kiri-art/docker-diffusers-api/commit/52edf6b8e52cba4a03c8ea0f72b8fd1e69fa87ad))
# [1.2.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.1.0...v1.2.0) (2023-01-04)
### Features
* **build:** separate MODEL_REVISION, MODEL_PRECISION, HF_MODEL_ID ([fa9dd16](https://github.com/kiri-art/docker-diffusers-api/commit/fa9dd16b7369d37f3997ef46581df471bca8e7c1))
# [1.1.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.2...v1.1.0) (2023-01-04)
### Features
* **downloads:** allow HF_MODEL_ID call-arg (defaults to MODEL_ID) ([adaa7f6](https://github.com/kiri-art/docker-diffusers-api/commit/adaa7f67aba49058b2e52117e6eb0fed6417b773))
* **downloads:** allow separate MODEL_REVISION and MODEL_PRECISION ([6edc821](https://github.com/kiri-art/docker-diffusers-api/commit/6edc821da1593f34e4502352dba8f2f4cd808e95))
## [1.0.2](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.1...v1.0.2) (2023-01-01)
### Bug Fixes
* **diffusers:** bump to 2022-12-30 commit 62608a9 ([2f29165](https://github.com/kiri-art/docker-diffusers-api/commit/2f291655967a253b81da9f44c99d4ac68e1c8353))
## [1.0.1](https://github.com/kiri-art/docker-diffusers-api/compare/v1.0.0...v1.0.1) (2022-12-31)
### Bug Fixes
* **ci:** different token, https auth ([ecd0b5d](https://github.com/kiri-art/docker-diffusers-api/commit/ecd0b5d8efe734693ff9647cfc2d0bc0b8f90e42))
# 1.0.0 (2022-12-31)
### Bug Fixes
* **app:** clearPipelines() before loadModel() to free RAM ([ec45acf](https://github.com/kiri-art/docker-diffusers-api/commit/ec45acf7db7796682597d1d1c440d3742df84425))
* **app:** init: don't process MODEL_ID if not RUNTIME_DOWNLOADS ([683677f](https://github.com/kiri-art/docker-diffusers-api/commit/683677f0bdbd49c11cb0310c7c365047b536a4f7))
* **dockerfile:** bump diffusers to eb1abee693104dd45376dbddd614320f2a0beb24 ([1769330](https://github.com/kiri-art/docker-diffusers-api/commit/1769330d4ec1f5932591383daf078be0953accdc))
* **downloads:** model_url, model_id should be optional ([9a19e7e](https://github.com/kiri-art/docker-diffusers-api/commit/9a19e7e1e742c46471f9a7e6fcebacea5f887d35))
* **dreambooth:** don't crash on cleanup when no class_data_dir created ([36e64b1](https://github.com/kiri-art/docker-diffusers-api/commit/36e64b101bb12c7e09445f5958acaab1ab59a301))
* **dreambooth:** enable mixed_precision training, default to fp16 ([0430d23](https://github.com/kiri-art/docker-diffusers-api/commit/0430d2380b5c6e5e43f2c8657017ba701bfaec41))
* **getScheduler:** fix deprecation warning s/from_config/from_pretrained/ ([92b2b43](https://github.com/kiri-art/docker-diffusers-api/commit/92b2b433bd9dfb4e1af1473cfa430e55bc83b170))
* **pipelines:** community pipelines, set torch_dtype too ([0cc1b63](https://github.com/kiri-art/docker-diffusers-api/commit/0cc1b63f72f98ad9267cdc71707bb4b533ad303d))
* **pipelines:** fix clearPipelines(), load model w/ correct precision ([3085412](https://github.com/kiri-art/docker-diffusers-api/commit/308541243c78cf528ebcd4c68900f5cdd52e6f8f))
* **requirements:** bumps transformers from 4.22.2 to 4.25.1 ([b13b58c](https://github.com/kiri-art/docker-diffusers-api/commit/b13b58c89fcd30e90ebb58c193c803450db43ebd))
* **s3:** incorrect value for tqdm causing crash ([9527ece](https://github.com/kiri-art/docker-diffusers-api/commit/9527ece90e4b5b4366f1c418d837dd659764203c))
* **send:** container_id detection, use /containers/ to grep ([5c0606a](https://github.com/kiri-art/docker-diffusers-api/commit/5c0606a0fdfd9b1a410b6f96eff009da6b768dbe))
* **tests:** default to DPMSolverMultistepScheduler and 20 steps ([a9c7bb0](https://github.com/kiri-art/docker-diffusers-api/commit/a9c7bb091821640a84d37d3090d365b7a54f2615))
### Features
* ability for custom config.yaml in CHECKPOINT_CONFIG_URL ([d2b507c](https://github.com/kiri-art/docker-diffusers-api/commit/d2b507ca225a033dda35897999e489541faecb8c))
* add PyPatchMatch for outpainting support ([3675bd3](https://github.com/kiri-art/docker-diffusers-api/commit/3675bd31a12d7b1f9627e34f59b661ea7261c272))
* **app:** don't track downloads in mem, check on disk ([51729e2](https://github.com/kiri-art/docker-diffusers-api/commit/51729e21440e4f0721b73ea497ddd2136306f11d))
* **app:** runtime downloads with MODEL_URL ([7abc4ac](https://github.com/kiri-art/docker-diffusers-api/commit/7abc4aced15f4aec441d4c220f39e046d2e35179))
* **app:** runtime downloads, re-use loaded model if requested again ([b84e822](https://github.com/kiri-art/docker-diffusers-api/commit/b84e822cacdb249693a301eb62a600ac9e0ee8f9))
* **callInputs:** `MODEL_ID`, `PIPELINE`, `SCHEDULER` now optional ([ef420a1](https://github.com/kiri-art/docker-diffusers-api/commit/ef420a1022b3d80950e7df79f1aff006e775c313))
* **cloud_cache:** normalize model_id and include precision ([ad1b2ef](https://github.com/kiri-art/docker-diffusers-api/commit/ad1b2efc60216c7a8854139ae816d78f6c4a9a19))
* **diffusers:** bump to v0.10.12 and one commit after (6b68afd) ([ec9117b](https://github.com/kiri-art/docker-diffusers-api/commit/ec9117b747985b7b3d80a4211c4e7bf6253a24a1))
* **diffusers:** bump to v0.9.0 ([0504d97](https://github.com/kiri-art/docker-diffusers-api/commit/0504d97e38eb85924ef7453c3c8690428f54870d))
* **docker:** diffusers-api-base image, build, run.sh ([1cbfc4f](https://github.com/kiri-art/docker-diffusers-api/commit/1cbfc4f41b46ea8d38600ac6902cf5f095357344))
* **dockerfile:** FROM_IMAGE build-arg to pick base image ([a0c37a6](https://github.com/kiri-art/docker-diffusers-api/commit/a0c37a6a87b300771f6ecf168b8bb1516caa5ab9))
* **Dockerfile:** make SDv2 the default (+ some formatting cleanup) ([c1e73ef](https://github.com/kiri-art/docker-diffusers-api/commit/c1e73efcdb6e5c95d36c83f9d1398182a1b7e77e))
* **dockerfile:** runtime downloads ([b40ae86](https://github.com/kiri-art/docker-diffusers-api/commit/b40ae868ce59ddb0232bcdb27ebb0a2c91068f51))
* **Dockerfile:** SAFETENSORS_FAST_GPU ([62209be](https://github.com/kiri-art/docker-diffusers-api/commit/62209be9963f9699ba32ea7520a361545b55034e))
* **download:** default_path as normalized_model_id.tar.zst ([5ad0d88](https://github.com/kiri-art/docker-diffusers-api/commit/5ad0d88b0b9b5a5a07596457c3bc83b7b32b25f5))
* **download:** delete .zst file after uncompress ([ab25280](https://github.com/kiri-art/docker-diffusers-api/commit/ab25280125bc1ccc38a0a2588fc09e33a576f6b0))
* **download:** record download timings ([7457e50](https://github.com/kiri-art/docker-diffusers-api/commit/7457e505c826c44d9f45a05fe486e819d442b4ca))
* **downloads:** runtime checkpoint conversion ([2414cd9](https://github.com/kiri-art/docker-diffusers-api/commit/2414cd9e3ac232273a1f2441134c65c25d0f7b49))
* **dreambooth:** save in safetensors format, tar up with -v ([5c3e86a](https://github.com/kiri-art/docker-diffusers-api/commit/5c3e86a8f99331c41c34b36c932b70e11f7b80b0))
* **errors:** try...catch everything, return as JSON ([901679c](https://github.com/kiri-art/docker-diffusers-api/commit/901679c7829796dc585af25f658cd6ab9115c7e7))
* **getScheduler:** make DPMSolverMultistepScheduler the default ([085d06f](https://github.com/kiri-art/docker-diffusers-api/commit/085d06f6b993a24b16521a1c3ee77d92289e04ed))
* **k-diffusion:** add pip package for use in k-diffusion schedulers ([3e901ad](https://github.com/kiri-art/docker-diffusers-api/commit/3e901adc64f750f5501b5dd19d87d0a5e294de22))
* **models:** store in ~/.cache/diffusers-api (volume support) ([8032ec1](https://github.com/kiri-art/docker-diffusers-api/commit/8032ec11b8f6590015110c9b89437f5619f2374c))
* **pipelines:** allow calling of ALL PIPELINES (official+community) ([1ccbaad](https://github.com/kiri-art/docker-diffusers-api/commit/1ccbaad1f405b8e5d16ca1a9880cc1d279f6d3f9))
* **pipelines:** initial community pipeline support ([7af45cf](https://github.com/kiri-art/docker-diffusers-api/commit/7af45cfdc4cbcc95c905834628775d0e8858509e))
* **s3:** s3client(), file_exists() methods ([0308af9](https://github.com/kiri-art/docker-diffusers-api/commit/0308af910d07be6d912104663263663b086def9c))
* **s3:** upload/download progress indicators ([76dd303](https://github.com/kiri-art/docker-diffusers-api/commit/76dd303a58a57b90ecc2c0038547b23b906ecca5))
* **send:** prefer env var CONTAINER_ID if set to full docker uuid ([eec5112](https://github.com/kiri-art/docker-diffusers-api/commit/eec511252035b8205f5365f45abb5777c164cb57))
* **send:** SEND_URL and SIGN_KEY now settable with build-vars ([01cf354](https://github.com/kiri-art/docker-diffusers-api/commit/01cf35461c5855a75651a30e3aeccb4ad1e9c8ac))
* **test:** allow TEST_URL to override https://localhost:8000/ ([9b46387](https://github.com/kiri-art/docker-diffusers-api/commit/9b463872257c0a3ffae553765aed62a2df6af717))
* **tests:** allow override BANANA_API_URL ([aca6aca](https://github.com/kiri-art/docker-diffusers-api/commit/aca6aca6e7ed46d0bf711548cea82a588fdd7d2a))
# CHANGELOG
* **NEXT MAIN**
* callInputs `MODEL_ID`, `PIPELINE` and `SCHEDULER` are **now optional**.
If not specified, the default will be used, and returned in a `$meta`
key in the result.
* Tests: 1) Don't specify above defaults where possible, 2) Log exact
inputs sent to container, 3) Log the full result sent back,
substituting base64 image strings with their info, 4) format stack
traces on caught errors from container.
* **NEXT MAIN (and already posted to forum)**
* **Latest diffusers, SDv2.1**. All the latest goodness, and upgraded some
dependencies too. Models are:
* `stabilityai/stable-diffusion-2-1-base` (512x512)
* `stabilityai/stable-diffusion-2-1` (768x768)
* **ALL THE PIPELINES**. We no longer load a list of hard-coded pipelines
in `init()`. Instead, we init and cache each on first use (for faster
first calls on cold boots), and, *all* pipelines, both official diffusers
and community pipelines, are available.
[Full details](https://banana-forums.dev/t/all-your-pipelines-are-belong-to-us/83)
* Dreambooth: Enable `mixed_precision` training, default to fp16.
* [Experimental] **[Runtime downloads](https://banana-forums.dev/t/runtime-downloads-dont-download-during-build/81/3)** (Dreambooth
only for now, more on the way)
* **S3**: Add upload/download progress indicators.
* Stable Diffusion has standardized **`image` instead of `init_image`** for
all pipelines. Using `init_image` now shows a deprecation warning and
will be removed in future.
* **Changed `sd-base` to `diffusers-api`** as the default tag / name used
in the README examples and optional [build script](./build).
* **Much better error handling**. We now `try...except` both the pipeline
run and entire `inference()` call, which will save you a trip to banana's
logs which don't always even show these errors and sometimes just leave
you with an unexplained stuck instance. These kinds of errors are almost
always a result of problematic callInputs and modelInputs used for the
pipeline call, so finding them will be a lot easier now.
* **2022-11-29**
* **Diffusers v0.9.0, Stable Diffusion v2.0**. Models:
* `"stabilityai/stable-diffusion-2"` - trained on 768x768
* `"stabilityai/stable-diffusion-2-base"` - trained on 512x512
* `"stabilityai/stable-diffusion-2-inpainting"` - untested
* `"stabilityai/stable-diffusion-x4-upscaler"` - untested
> https://github.com/huggingface/diffusers/releases
**NB**: SDv2 does not include a safety_checker. The model itself is
"safe" (it's much harder to create NSFW content). Trying to "turn off"
the (non-existent) safety checker will throw an error, we'll handle this
more gracefully in a future release. This also means you can safely
ignore this warning on loading:
```
You have disabled the safety checker for
<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>
by passing safety_checker=None. Ensure that...
```
* **DPMSolverMultistepScheduler**. Docker-diffusers-api is simply a wrapper
around diffusers. We support all the included schedulers out of the box,
as long as they can init themselves with default arguments. So, the above
scheduler was already working, but we didn't mention it before. I'll just
quote diffusers:
> DPMSolverMultistepScheduler is the firecracker diffusers implementation
of DPM-Solver++, a state-of-the-art scheduler that was contributed by one
of the authors of the paper. This scheduler is able to achieve great
quality in as few as 20 steps. It's a drop-in replacement for the default
Stable Diffusion scheduler, so you can use it to essentially half
generation times.
* **Storage Class / S3 support**. We now have a generic storage class, which
allows for special URLs anywhere you can usually specify a URL,
e.g. `CHECKPOINT_URL`, `dest_url` (after dreambooth training), and the new
`MODEL_URL` (see below). URLs like "s3:///bucket/filename" will work how
you expect, but definitely read [docs/storage.md](./docs/storage.md)
to understand the format better. Note in particular the triple forward slash
("///") in the beginning to use the default S3 endpoint.
* **Dreambooth training**, working but still in development. See
[this forum post](https://banana-forums.dev/t/dreambooth-training-first-look/36)
for more info.
* **`PRECISION`** build var, defaults to `"fp16"`, set to `""` to use the model
defaults (generally fp32).
* **`CHECKPOINT_URL` conversion**:
* Crash / stop build if conversion fails (rather than unclear errors later on)
* Force `cpu` loading even for models that would otherwise default to GPU.
This fixes certain models that previously crashed in build stage (where GPU
is not available).
* `--extract-ema` on conversion since these are the more important weights for
inference.
* `CHECKPOINT_CONFIG_URL` now lets you specify a specific config file for
conversion, to use instead of SD's default `v1-inference.yaml`.
* **`MODEL_URL`**. If your model is already in diffusers format, but you don't
host it on HuggingFace, you can now have it downloaded at build time. At
this stage, it should be a `.tar.zst` file. This is an *alternative* to
`CHECKPOINT_URL` which downloads a `.ckpt` file and converts to diffusers.
* **`test.py`**:
* New `--banana` arg to run the test on banana. Set environment variables
`BANANA_API_KEY` and `BANANA_MODEL_KEY` first.
* You can now add to and override a test's default json payload with:
* `--model-arg prompt="hello"`
* `--call-arg MODEL_ID="my-model"`
* Support for extra timing data (e.g. dreambooth sends `train`
and `upload` timings).
* Quit after inference errors, don't keep looping.
* **Dev: better caching solution**. No more unruly `root-cache` directory. See
[CONTRIBUTING.md](./CONTRIBUTING.md) for more info.
* **2022-11-08**
* **Much faster `init()` times!** For `runwayml/stable-diffusion-v1-5`:
* Previously: 4.0s, now: 2.4s (40% speed gain)
* **Much faster `inference()` times!** Particularly from the 2nd inference onwards.
Here's a brief comparison of *inference* average times (for 512x512 x50 steps):
* [Cold] Previously: 3.8s, now: 3.3s (13% speed gain)
* [Warm] Previously: 3.2s, now: 2.1s (34% speed gain)
* **Improved `test.py`**, see [Testing](./README.md#testing)
* **2022-11-05**
* Upgrade to **Diffusers v0.7.0**. There is a lot of fun stuff in this release,
but notably for docker-diffusers-api TODAY (more fun stuff coming next week!),
we have **much faster init times** (via
[`fast_load`](https://github.com/huggingface/diffusers/commit/7482178162b779506a54538f2cf2565c8b88c597)
) and the greatly anticipated support for the Euler schedulers (
[a1ea8c0](https://github.com/huggingface/diffusers/commit/a1ea8c01c31a44bf48f6a3b85ccabeb45ef6418f)
).
* We now use the **full scheduler name** for `callInputs.SCHEDULER`. `"LMS"`,
`"DDIM"`, `"PNDM"` all still work fine for now but give a deprecation warning
and will stop working in a future update. The full list of supported schedulers
is: `LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`,
`EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`. These cover the
most commonly used / requested schedulers, but we already have code in place to
support every scheduler provided by diffusers, which will work in a later
diffusers release when they have better defaults.
* **2022-10-24**
* **Fixed img2img and inpainting pipelines**. To my great shame, in my rush to get
the new models out before the weekend, I inadvertently broke the above two models.
Please accept my sincere apology for any confusion this may have caused and
especially any of your wasted time in debugging this 🙇
* **Event logs now shown without `SEND_URL`**. We optionally log useful info at the
start and end of `init()` and `inference()`. Previously this was only logged if
`SEND_URL` was set, to send to an external REST API for logging. But now, even if
we don't send it anywhere, we'll still log this useful info. It now also logs
the `diffusers` version too.
* **2022-10-21**
* **Stable Diffusion 1.5 released!!!**
Accept the license at:
["runwayml/stable-diffusion-v1-5"](https://huggingface.co/runwayml/stable-diffusion-v1-5)
It's the new default model.
* **Official Stable Diffusion inpainting model**
Accept the license at:
["runwayml/stable-diffusion-inpainting"](https://huggingface.co/runwayml/stable-diffusion-inpainting),
A few big caveats!
1) Different model - so back to a separate container for inpainting, also because:
2) New pipeline that can't share model struct with other pipelines
(see [diffusers#920](https://github.com/huggingface/diffusers/issues/920)).
3) Old pipeline is now called `StableDiffusionInpaintPipelineLegacy` (for sd-1.4)
4) `model_input` takes `image` now, and not `init_image` like the legacy model.
5) There is no `strength` parameter in the new model
(see [diffusers#920](https://github.com/huggingface/diffusers/issues/920)).
* Upgrade to **Diffusers v0.7.0.dev0**
* **Flash attention** now disabled by default. 1) Because it's built on
an older version of diffusers, but also because 2) I didn't succeed in
getting much improvement out of it. Maybe someone else will have better
luck. I think you need big batch sizes to really see the benefit, which
doesn't suit my use case. But please anyone who figures anything out,
let us know.
================================================
FILE: CONTRIBUTING.md
================================================
# CONTRIBUTING
*Tips for development*
1. [General Hints](#general)
1. [Development / Editor Setup](#editors)
1. [Visual Studio Code (vscode)](#vscode)
1. [Testing](#testing)
1. [Using Buildkit](#buildkit)
1. [Local HTTP(S) Caching Proxy](#caching)
1. [Local S3 Server](#local-s3-server)
1. [Stop on Suspend](#stop-on-suspend)
<a name="general"></a>
## General
1. Run docker with `-it` to make it easier to stop container with `Ctrl-C`.
1. If you get a `CUDA initialization: CUDA unknown error` after suspend,
just stop the container, `rmmod nvidia_uvm`, and restart.
<a name="editors"></a>
## Editors
<a name="vscode"></a>
### Visual Studio Code (recommended, WIP)
*We're still writing this guide, let us know of any needed improvements*
This repo includes VSCode settings that allow for a) editing inside a docker container, b) tests and coverage (on save)
1. Install from https://code.visualstudio.com/
1. Install [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension.
1. Open your docker-diffusers-api folder, you'll get a popup in the bottom right that a dev container environment was detected, click "reload in container"
1. Look for the "( ) Watch" on status bar and click it so it changes to "( ) XX Coverage"
**Live Development**
1. **Run Task** (either Ctrl-Shift-P and "Run Task", or in Terminals, the Plus ("+") DROPDOWN selector and choose, "Run Task..." at the bottom)
1. Choose **Watching Server**. Port 8000 will be forwarded. The server will be reloaded
on every file save (make sure to give it enough time to fully load before sending another
request, otherwise that request will hang).
<a name="testing"></a>
## Testing
1. **Unit testing**: exists but is sorely lacking for now. If you use the
recommended editor setup above, it's probably working already. However:
1. **Integration / E2E**: cover most features used in production.
`pytest -s tests/integration`.
The `-s` is optional but streams stdout so you can follow along.
Add also `-k test_name` to test a specific test. E2E tests are LONG but you can
greatly reduce subsequent run time by following the steps below for a
[Local HTTP(S) Caching Proxy](#caching) and [Local S3 Server](#local-s3-server).
Docker-Diffusers-API follows Semantic Versioning. We follow the
[conventional commits](https://www.conventionalcommits.org/en/v1.0.0/)
standard.
* On a commit to `dev`, if all CI tests pass, a new release is made to `:dev` tag.
* On a commit to `main`, if all CI tests pass, a new release with appropriate
major / minor / patch is made, based on appropriate tags in the commit history.
<a name="buildkit"></a>
## Using BuildKit
Buildkit is a docker extension that can really improve build speeds through
caching and parallelization. You can enable and tweak it by adding:
`DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain`
vars before `docker build` (the `PROGRESS` var shows much more detailed
build logs, which can be useful, but are much more verbose). This is
already all set up in the [build](./build) script.
<a name="caching"></a>
## Local HTTP(S) Caching Proxy
If you're only editing e.g. `app.py`, there's no need to worry about caching
and the docker layers work amazingly. But, if you're constantly changing
installed packages (apt, `requirements.txt`), `download.py`, etc, it's VERY
helpful to have a local cache:
```bash
# See all options at https://hub.docker.com/r/gadicc/squid-ssl-zero
$ docker run -d -p 3128:3128 -p 3129:80 \
--name squid --restart=always \
-v /usr/local/squid:/usr/local/squid \
gadicc/squid-ssl-zero
```
and then set the docker build args `proxy=1`, and `http_proxy` / `https_proxy`
with their respective values.
This is already all set up in the [build](./build) script.
**You probably want to fine-tune /usr/local/squid/etc/squid.conf**.
It will be created after you first run `gadicc/squid-ssl-zero`. You can then
stop the container (`docker ps`, `docker stop container_id`), edit the file,
and re-start (`docker start container_id`). For now, try something like:
```conf
cache_dir ufs /usr/local/squid/cache 50000 16 256 # 50GB
maximum_object_size 20 GB
refresh_pattern . 52034400 50% 52034400 store-stale override-expire ignore-no-cache ignore-no-store ignore-private
```
but ideally we can as a community create some rules that don't so
aggressively catch every single request.
<a name="local-s3-server"></a>
## Local S3 server
If you're doing development around the S3 handling, it can be very useful to
have a local S3 server, especially due to the large size of models. You
can set one up like this:
```bash
$ docker run -p 9000:9000 -p 9001:9001 \
-v /usr/local/minio:/data quay.io/minio/minio \
server /data --console-address ":9001"
```
Now point a web browser to http://localhost:9001/, login with the default
root credentials `minioadmin:minioadmin` and create a bucket and credentials
for testing. More info at https://hub.docker.com/r/minio/minio/.
Typical policy:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "VisualEditor0",
"Effect": "Allow",
"Action": [
"s3:PutObject",
"s3:GetObject"
],
"Resource": "arn:aws:s3:::BUCKET_NAME/*"
}
]
}
```
Then set the **build-arg** `AWS_S3_ENDPOINT_URL="http://172.17.0.1:9000"`
or as appropriate if you've changed the default docker network.
<a name="stop-on-suspend"></a>
## Stop on Suspend
Maybe it's just me, but frequently I'll have issues when suspending with
the container running (I guess it's a CUDA issue), either a freeze on resume,
or a stuck-forever defunct process. I found it useful to automatically stop
the container / process on suspend.
I'm running ArchLinux and set up a `systemd` suspend hook as described
[here](https://wiki.archlinux.org/title/Power_management#Sleep_hooks), to
call a script, which contains:
```bash
# Stop a matching docker container
PID=`docker ps -qf ancestor=gadicc/diffusers-api`
if [ ! -z $PID ] ; then
echo "Stopping diffusers-api pid $PID"
docker stop $PID
fi
# For a VSCode devcontainer, just kill the watchmedo process.
PID=`docker ps -qf volume=/home/dragon/root-cache`
if [ ! -z $PID ] ; then
echo "Stopping watchmedo in container $PID"
docker exec $PID /bin/bash -c 'kill `pidof -sx watchmedo`'
fi
```
================================================
FILE: Dockerfile
================================================
ARG FROM_IMAGE="pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime"
# ARG FROM_IMAGE="gadicc/diffusers-api-base:python3.9-pytorch1.12.1-cuda11.6-xformers"
# You only need the -banana variant if you need banana's optimization
# i.e. not relevant if you're using RUNTIME_DOWNLOADS
# ARG FROM_IMAGE="gadicc/python3.9-pytorch1.12.1-cuda11.6-xformers-banana"
# Shared base stage; the patchmatch and output stages both build FROM it.
FROM ${FROM_IMAGE} as base
# An ARG declared before the first FROM is only in scope for FROM lines.
# Re-declare it here (inheriting the default above) so it can be referenced
# inside this stage; without this, ENV FROM_IMAGE would always be empty.
ARG FROM_IMAGE
ENV FROM_IMAGE=${FROM_IMAGE}
# Note, docker uses HTTP_PROXY and HTTPS_PROXY (uppercase)
# We purposefully want those managed independently, as we want docker
# to manage its own cache. This is just for pip, models, etc.
ARG http_proxy
ARG https_proxy
# When a proxy is configured, connect to google.com through it, keep the
# certificate(s) printed by `openssl s_client -showcerts`, and add them to
# the system trust store so pip / downloads accept the proxy's TLS.
# NOTE(review): the condition tests http_proxy but the command uses
# https_proxy; `cut -b 8-` drops the first 7 bytes, which assumes an
# "http://" prefix — confirm both are intended.
RUN if [ -n "$http_proxy" ] ; then \
echo quit \
| openssl s_client -proxy $(echo ${https_proxy} | cut -b 8-) -servername google.com -connect google.com:443 -showcerts \
| sed 'H;1h;$!d;x; s/^.*\(-----BEGIN CERTIFICATE-----.*-----END CERTIFICATE-----\)\n---\nServer certificate.*$/\1/' \
> /usr/local/share/ca-certificates/squid-self-signed.crt ; \
update-ca-certificates ; \
fi
# Only set (pointing at the certificate above) when a proxy is in use.
ARG REQUESTS_CA_BUNDLE=${http_proxy:+/usr/local/share/ca-certificates/squid-self-signed.crt}
ARG DEBIAN_FRONTEND=noninteractive
ARG TZ=UTC
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# The package lists fetched here are also relied upon by the later
# `apt-get install` calls in the output stage (which don't re-run update).
RUN apt-get update
RUN apt-get install -yq apt-utils
RUN apt-get install -yqq git zstd wget curl
# Builds PyPatchMatch in its own stage; the output stage copies only
# /tmp/PyPatchMatch from here (COPY --from=patchmatch).
FROM base AS patchmatch
# Build args are exposed as environment variables during RUN, so the setup
# script can read this. Presumably the script skips the real build when it
# is 0 but still leaves /tmp/PyPatchMatch in place — TODO confirm.
ARG USE_PATCHMATCH=0
WORKDIR /tmp
COPY scripts/patchmatch-setup.sh .
RUN sh patchmatch-setup.sh
# Final runtime image: python deps, pinned diffusers, optional extras, API code.
FROM base as output
RUN mkdir /api
WORKDIR /api
# we use latest pip in base image
# RUN pip3 install --upgrade pip
ADD requirements.txt requirements.txt
RUN pip install -r requirements.txt
# Pinned diffusers commit:
# [Import] Add missing settings / Correct some dummy imports (#5036) - 2023-09-14
ARG DIFFUSERS_VERSION="3aa641289c995b3a0ce4ea895a76eb1128eff30c"
ENV DIFFUSERS_VERSION=${DIFFUSERS_VERSION}
RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout ${DIFFUSERS_VERSION}
WORKDIR /api
RUN pip install -e diffusers
# Set to 1 (the default) to NOT download the model at build time, but
# rather at init / usage time.
ARG RUNTIME_DOWNLOADS=1
ENV RUNTIME_DOWNLOADS=${RUNTIME_DOWNLOADS}
# TODO, to dda-banana
# ARG PIPELINE="StableDiffusionInpaintPipeline"
ARG PIPELINE="ALL"
ENV PIPELINE=${PIPELINE}
# Deps for RUNNING (not building) earlier options.
# Note: the apt-get installs in this stage rely on the package lists
# fetched by `apt-get update` in the base stage.
ARG USE_PATCHMATCH=0
RUN if [ "$USE_PATCHMATCH" = "1" ] ; then apt-get install -yqq python3-opencv ; fi
COPY --from=patchmatch /tmp/PyPatchMatch PyPatchMatch
# TODO, just include by default, and handle all deps in OUR requirements.txt
ARG USE_DREAMBOOTH=1
ENV USE_DREAMBOOTH=${USE_DREAMBOOTH}
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \
# By specifying the same torch version as conda, it won't download again.
# Without this, it will upgrade torch, break xformers, make bigger image.
# bitsandbytes==0.40.0.post4 had failed cuda detection on dreambooth test.
pip install -r diffusers/examples/dreambooth/requirements.txt ; \
fi
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then apt-get install -yqq git-lfs ; fi
ARG USE_REALESRGAN=1
RUN if [ "$USE_REALESRGAN" = "1" ] ; then apt-get install -yqq libgl1-mesa-glx libglib2.0-0 ; fi
RUN if [ "$USE_REALESRGAN" = "1" ] ; then git clone https://github.com/xinntao/Real-ESRGAN.git ; fi
# RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install numba==0.57.1 chardet ; fi
RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install basicsr==1.4.2 facexlib==0.2.5 gfpgan==1.3.8 ; fi
RUN if [ "$USE_REALESRGAN" = "1" ] ; then cd Real-ESRGAN && python3 setup.py develop ; fi
COPY api/ .
EXPOSE 8000
ARG SAFETENSORS_FAST_GPU=1
ENV SAFETENSORS_FAST_GPU=${SAFETENSORS_FAST_GPU}
# Exec form: python runs as PID 1 and receives SIGTERM / SIGINT directly
# (`docker stop`, Ctrl-C), rather than through a `/bin/sh -c` wrapper.
CMD ["python3", "-u", "server.py"]
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Banana, Gadi Cohen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# docker-diffusers-api ("banana-sd-base")
Diffusers / Stable Diffusion in docker with a REST API, supporting various models, pipelines & schedulers. Used by [kiri.art](https://kiri.art/), perfect for local, server & serverless.
[](https://hub.docker.com/r/gadicc/diffusers-api/tags) [](https://circleci.com/gh/kiri-art/docker-diffusers-api?branch=split) [](https://github.com/semantic-release/semantic-release) [](./LICENSE) [](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/kiri-art/docker-diffusers-api)
Copyright (c) Gadi Cohen, 2022. MIT Licensed.
Please give credit and link back to this repo if you use it in a public project.
## Features
* Models: stable-diffusion, waifu-diffusion, and easy to add others (e.g. jp-sd)
* Pipelines: txt2img, img2img and inpainting in a single container
([all diffusers official and community pipelines](https://forums.kiri.art/t/all-your-pipelines-are-belong-to-us/83) are wrapped, but untested)
* All model inputs supported, including setting nsfw filter per request
* *Permute* base config to multiple forks based on yaml config with vars
* Optionally send signed event logs / performance data to a REST endpoint / webhook.
* Can automatically download a checkpoint file and convert to diffusers.
* S3 support, dreambooth training.
Note: This image was created for [kiri.art](https://kiri.art/).
Everything is open source but there may be certain request / response
assumptions. If anything is unclear, please open an issue.
## Important Notices
* [Official `docker-diffusers-api` Forum](https://forums.kiri.art/c/docker-diffusers-api/16):
help, updates, discussion.
* Subscribe ("watch") these forum topics for:
* [notable **`main`** branch updates](https://forums.kiri.art/t/official-releases-main-branch/35)
* [notable **`dev`** branch updates](https://forums.kiri.art/t/development-releases-dev-branch/53)
* Always [check the CHANGELOG](./CHANGELOG.md) for important updates when upgrading.
**Official help in our dedicated forum https://forums.kiri.art/c/docker-diffusers-api/16.**
**This README refers to the in-development `dev` branch** and may
reference features and fixes not yet in the published releases.
**`v1` has not yet been officially released** but has been
running well in production on kiri.art for almost a month. We'd
be grateful for any feedback from early adopters to help make
this official. For more details, see [Upgrading from v0 to
v1](https://forums.kiri.art/t/wip-upgrading-from-v0-to-v1/116).
Previous releases available on the `dev-v0-final` and
`main-v0-final` branches.
**Currently only NVIDIA / CUDA devices are supported**. Tracking
Apple / M1 support in issue
[#20](https://github.com/kiri-art/docker-diffusers-api/issues/20).
## Installation & Setup:
Setup varies depending on your use case.
1. **To run locally or on a *server*, with runtime downloads:**
`docker run --gpus all -p 8000:8000 -e HF_AUTH_TOKEN=$HF_AUTH_TOKEN gadicc/diffusers-api`.
See the [guides for various cloud providers](https://forums.kiri.art/t/running-on-other-cloud-providers/89/7).
1. **To run *serverless*, include the model at build time:**
1. [docker-diffusers-api-build-download](https://github.com/kiri-art/docker-diffusers-api-build-download) (
[banana](https://forums.kiri.art/t/run-diffusers-api-on-banana-dev/103), others)
1. [docker-diffusers-api-runpod](https://github.com/kiri-art/docker-diffusers-api-runpod),
see the [guide](https://forums.kiri.art/t/run-diffusers-api-on-runpod-io/102)
1. **Building from source**.
1. Fork / clone this repo.
1. `docker build -t gadicc/diffusers-api .`
1. See [CONTRIBUTING.md](./CONTRIBUTING.md) for more helpful hints.
*Other configurations are possible but these are the most common cases*
Everything is set via docker build-args or environment variables.
## Usage:
See also [Testing](#testing) below.
The container expects an `HTTP POST` request to `/`, with a JSON body resembling the following:
```json
{
"modelInputs": {
"prompt": "Super dog",
"num_inference_steps": 50,
"guidance_scale": 7.5,
"width": 512,
"height": 512,
"seed": 3239022079
},
"callInputs": {
// You can leave these out to use the default
"MODEL_ID": "runwayml/stable-diffusion-v1-5",
"PIPELINE": "StableDiffusionPipeline",
"SCHEDULER": "LMSDiscreteScheduler",
"safety_checker": true
}
}
```
It's important to remember that `docker-diffusers-api` is primarily a wrapper
around HuggingFace's
[diffusers](https://huggingface.co/docs/diffusers/index) library.
**Basic familiarity with `diffusers` is indispensable for a good experience
with `docker-diffusers-api`.** Explaining some of the options above:
* **modelInputs** - for the most part - are passed directly to the selected
diffusers pipeline unchanged. So, for the default `StableDiffusionPipeline`,
you can see all options in the relevant pipeline docs for its
[`__call__`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__) method. The main exceptions are:
* Only valid JSON values can be given (strings, numbers, etc)
* **seed**, a number, is transformed into a `generator`.
* **images** are converted to / from base64 encoded strings.
* **callInputs** affect which model, pipeline, scheduler and other lower
level options are used to construct the final pipeline. Notably:
* **`SCHEDULER`**: any scheduler included in diffusers should work out
of the box, provided it can be loaded with its default config and without
requiring any other explicit arguments at init time. In any event,
the following schedulers are the most common and most well tested:
`DPMSolverMultistepScheduler` (fast! only needs 20 steps!),
`LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`,
`EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`.
* **`PIPELINE`**: the most common are
[`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img),
[`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img),
[`StableDiffusionInpaintPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint), and the community
[`lpw_stable_diffusion`](https://forums.kiri.art/t/lpw-stable-diffusion-pipeline-longer-prompts-prompt-weights/82)
which allows for long prompts (more than 77 tokens) and prompt weights
(things like `((big eyes))`, `(red hair:1.2)`, etc), and accepts a
`custom_pipeline_method` callInput with values `text2img` ("text", not "txt"),
`img2img` and `inpaint`. See these links for all the possible `modelInputs`
that can be passed to the pipeline's `__call__` method.
* **`MODEL_URL`** (optional) can be used to retrieve the model from
locations other than HuggingFace, e.g. an `HTTP` server, S3-compatible
storage, etc. For more info, see the
[storage docs](https://github.com/kiri-art/docker-diffusers-api/blob/dev/docs/storage.md)
and
[this post](https://forums.kiri.art/t/safetensors-our-own-optimization-faster-model-init/98)
for info on how to use and store optimized models from your own cloud.
<a name="testing"></a>
## Examples and testing
There are also very basic examples in [test.py](./test.py), which you can view
and call `python test.py` if the container is already running on port 8000.
You can also specify a specific test, change some options, and run against a
deployed banana image:
```bash
$ python test.py
Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]
# Run against http://localhost:8000/ (Nvidia Quadro RTX 5000)
$ python test.py txt2img
Running test: txt2img
Request took 5.9s (init: 3.2s, inference: 5.9s)
Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png
# Run against deployed banana image (Nvidia A100)
$ export BANANA_API_KEY=XXX
$ BANANA_MODEL_KEY=XXX python3 test.py --banana txt2img
Running test: txt2img
Request took 19.4s (init: 2.5s, inference: 3.5s)
Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png
# Note that 2nd runs are much faster (ignore init, that isn't run again)
Request took 3.0s (init: 2.4s, inference: 2.1s)
```
The best example of course is https://kiri.art/ and it's
[source code](https://github.com/kiri-art/stable-diffusion-react-nextjs-mui-pwa).
## Help on [Official Forums](https://forums.kiri.art/c/docker-diffusers-api/16).
## Adding other Models
You have two options.
1. For a diffusers model, simply set `MODEL_ID` build-var / call-arg to the name
of the model hosted on HuggingFace, and it will be downloaded automatically at
build time.
1. For a non-diffusers model, simply set the `CHECKPOINT_URL` build-var / call-arg
to the URL of a `.ckpt` file, which will be downloaded and converted to the diffusers
format automatically at build time. `CHECKPOINT_CONFIG_URL` can also be set.
## Troubleshooting
* **403 Client Error: Forbidden for url**
Make sure you've accepted the license on the model card of the HuggingFace model
specified in `MODEL_ID`, and that you correctly passed `HF_AUTH_TOKEN` to the
container.
## Event logs / web hooks / performance data
Set `SEND_URL` (and optionally `SIGN_KEY`) environment variable(s) to send
event and timing data on `init`, `inference` and other start and end events.
This can either be used to log performance data, or for webhooks on event
start / finish.
The timing data is now returned in the response payload too, like this:
`{ $timings: { init: timeInMs, inference: timeInMs } }`, with any other
events (such as `training`, `upload`, etc).
You can go to https://webhook.site/ and use the provided "unique URL"
as your `SEND_URL` to see how it works, if you don't have your own
REST endpoint (yet).
If `SIGN_KEY` is used, you can verify the signature like this (TypeScript):
```ts
import crypto from "crypto";
async function handler(req: NextApiRequest, res: NextApiResponse) {
const data = req.body;
const containerSig = data.sig as string;
delete data.sig;
const ourSig = crypto
.createHash("md5")
.update(JSON.stringify(data) + process.env.SIGN_KEY)
.digest("hex");
const signatureIsValid = containerSig === ourSig;
}
```
If you send a callInput called `startRequestId`, it will get sent
back as part of the send payload in most cases.
You can also set callInputs `SEND_URL` and `SIGN_KEY` to
set or override these values on a per-request basis.
## Acknowledgements
* The container image is originally based on
https://github.com/bananaml/serverless-template-stable-diffusion.
* [CompVis](https://github.com/CompVis),
[Stability AI](https://stability.ai/),
[LAION](https://laion.ai/)
and [RunwayML](https://runwayml.com/)
for their incredible time, work and efforts in creating Stable Diffusion,
and no less so, their decision to release it publicly with an open source
license.
* [HuggingFace](https://huggingface.co/) - for their passion and inspiration
for making machine learning more accessible to developers, and in particular,
their [Diffusers](https://github.com/huggingface/diffusers) library.
================================================
FILE: __init__.py
================================================
================================================
FILE: api/app.py
================================================
import asyncio
from sched import scheduler
import torch
from torch import autocast
from diffusers import __version__
import base64
from io import BytesIO
import PIL
import json
from loadModel import loadModel
from send import send, getTimings, clearSession
from status import status
import os
import numpy as np
import skimage
import skimage.measure
from getScheduler import getScheduler, SCHEDULERS
from getPipeline import (
getPipelineClass,
getPipelineForModel,
listAvailablePipelines,
clearPipelines,
)
import re
import requests
from download import download_model, normalize_model_id
import traceback
from precision import MODEL_REVISION, MODEL_PRECISION
from device import device, device_id, device_name
from utils import Storage
from hashlib import sha256
from threading import Timer
import extras
import jxlpy
from jxlpy import JXLImagePlugin
from diffusers import (
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
pipelines as diffusers_pipelines,
AutoencoderTiny,
AutoencoderKL,
)
from lib.textual_inversions import handle_textual_inversions
from lib.prompts import prepare_prompts
from lib.vars import (
RUNTIME_DOWNLOADS,
USE_DREAMBOOTH,
MODEL_ID,
PIPELINE,
HF_AUTH_TOKEN,
HOME,
MODELS_DIR,
)
# Dreambooth training support is optional; only import it when enabled.
if USE_DREAMBOOTH:
    from train_dreambooth import TrainDreamBooth

# Log the raw USE_PATCHMATCH env value at startup.
print(os.environ.get("USE_PATCHMATCH"))
# PatchMatch is only needed for callInputs.FILL_MODE == "patchmatch".
if os.environ.get("USE_PATCHMATCH") == "1":
    from PyPatchMatch import patch_match

# Inference only: turn autograd off by default. The dreambooth path in
# inference() switches it back on while training.
torch.set_grad_enabled(False)
# Set by init() to a local model directory under MODELS_DIR, when one exists.
always_normalize_model_id = None
# Lazily-created AutoencoderTiny instance, cached by tinyVae().
tiny_vae = None
# still working on this, not in use yet.
def tinyVae(origVae: AutoencoderKL):
    """Return a cached TAESD ("madebyollin/taesd") AutoencoderTiny.

    The tiny VAE is created on the first call only. It is configured from
    `origVae.config` (channels, activation, latent channels, scaling factor,
    force_upcast), loaded as float16 and moved to "cuda". Later calls return
    the cached instance and ignore `origVae`.

    Still experimental - not in use yet.
    """
    global tiny_vae
    if tiny_vae:
        return tiny_vae

    cfg = origVae.config
    tiny_vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesd",
        torch_dtype=torch.float16,
        in_channels=cfg.in_channels,
        out_channels=cfg.out_channels,
        act_fn=cfg.act_fn,
        latent_channels=cfg.latent_channels,
        scaling_factor=cfg.scaling_factor,
        force_upcast=cfg.force_upcast,
    )
    # NOTE(review): device is hard-coded to CUDA here.
    tiny_vae.to("cuda")
    return tiny_vae
# Init is ran on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
    """Server start-up hook: announce start-up and, unless models are
    fetched per request, load the configured model into the global `model`.

    Side effects:
      * sends "init" start / done events via `send`;
      * sets the globals `model`, `last_model_id` (always reset to None) and,
        when a normalized model directory already exists under MODELS_DIR,
        `always_normalize_model_id`.
    """
    global model  # needed for banana optimizations
    global always_normalize_model_id
    global last_model_id

    asyncio.run(
        send(
            "init",
            "start",
            {
                "device": device_name,
                "hostname": os.getenv("HOSTNAME"),
                "model_id": MODEL_ID,
                "diffusers": __version__,
            },
        )
    )

    # inference() compares against last_model_id to decide whether to
    # (re)load a model; make sure the name is always bound, whatever the
    # MODEL_ID / RUNTIME_DOWNLOADS configuration.
    last_model_id = None

    if RUNTIME_DOWNLOADS:
        # Models are downloaded / loaded per request in inference().
        model = None
    else:
        normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
        model_dir = os.path.join(MODELS_DIR, normalized_model_id)
        if os.path.isdir(model_dir):
            # Prefer the local, already-normalized copy of the model.
            always_normalize_model_id = model_dir

        model = loadModel(
            model_id=always_normalize_model_id or MODEL_ID,
            load=True,
            precision=MODEL_PRECISION,
            revision=MODEL_REVISION,
        )

    asyncio.run(send("init", "done"))
def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
    """Decode a base64-encoded image string into a PIL image.

    `name` is only used in the log line that reports format and size.
    """
    raw_bytes = base64.decodebytes(bytes(imageStr, "utf-8"))
    decoded = PIL.Image.open(BytesIO(raw_bytes))
    print(
        f'Decoded image "{name}": {decoded.format} {decoded.width}x{decoded.height}'
    )
    return decoded
def getFromUrl(url: str, name: str, timeout: float = 60) -> PIL.Image:
    """Download an image from `url` and open it with PIL.

    Args:
        url: HTTP(S) location of the image.
        name: label used only in the log line.
        timeout: seconds passed to `requests.get`. Without a timeout a
            stalled server could block the request forever.

    Raises:
        requests.HTTPError: for non-2xx responses. Previously the error
            page was handed to PIL, which failed with an opaque
            "cannot identify image file" error.
        requests.Timeout: if the server does not respond within `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = PIL.Image.open(BytesIO(response.content))
    print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}')
    return image
def truncateInputs(inputs: dict):
    """Return a copy of `inputs` that is safe to log.

    Base64 image payloads inside `modelInputs` (`init_image`, `mask_image`,
    `image`, `input_image` and each entry of `instance_images`) are cut to
    their first 6 characters plus "...". The caller's dicts are not mutated.

    Robust to a missing/`None`/non-dict `modelInputs` and to non-string image
    values, which are left untouched instead of raising. Previously
    `{"modelInputs": null}` crashed here before inference() could report
    INVALID_INPUTS.
    """

    def _truncate(value):
        # Only strings (base64 payloads / URLs) are shortened.
        return value[0:6] + "..." if isinstance(value, str) else value

    clone = inputs.copy()
    model_inputs = clone.get("modelInputs")
    if isinstance(model_inputs, dict):
        model_inputs = clone["modelInputs"] = model_inputs.copy()
        for item in ("init_image", "mask_image", "image", "input_image"):
            if item in model_inputs:
                model_inputs[item] = _truncate(model_inputs[item])
        instance_images = model_inputs.get("instance_images")
        if isinstance(instance_images, (list, tuple)):
            model_inputs["instance_images"] = [
                _truncate(instance_image) for instance_image in instance_images
            ]
    return clone
# last_xformers_memory_efficient_attention = {}
# State carried between inference() calls so that unchanged settings are
# not re-applied: the last `attn_procs` value seen, and the JSON-serialized
# `lora_weights` that were last applied to the pipeline.
last_attn_procs = None
last_lora_weights = None
# LoRA-related cross_attention_kwargs (e.g. {"scale": ...}) for the pipeline;
# reset by inference() when the model or the LoRA weights change.
cross_attention_kwargs = None
# Inference is ran for every server call
# Reference your preloaded global model variable here.
async def inference(all_inputs: dict, response) -> dict:
    """Serve one API request.

    Depending on `callInputs`, this (re)loads the requested model and
    pipeline, applies scheduler / safety checker / textual inversion / LoRA
    settings, then either runs an "extra", trains (dreambooth), or runs the
    diffusers pipeline and returns base64-encoded images.

    Args:
        all_inputs: request body, expected to look like
            ``{"modelInputs": {...}, "callInputs": {...}}``.
        response: optional streaming response. When truthy, a status line
            (``json.dumps(status.get())``) is sent about once a second and
            the response is forwarded to ``send`` through ``send_opts``.

    Returns:
        A result dict (``image_base64`` or ``images_base64``, ``$meta``,
        ``$timings``, ``$mem_usage``, optionally ``nsfw_content_detected``)
        or ``{"$error": {...}}`` describing why the request failed.

    Note: mutates module-level state (loaded model, last model id, last
    LoRA weights, LoRA cross attention kwargs) so consecutive requests can
    reuse already-loaded weights.
    """
    global model
    global pipelines
    global last_model_id
    global schedulers
    # global last_xformers_memory_efficient_attention
    global always_normalize_model_id
    global last_attn_procs
    global last_lora_weights
    global cross_attention_kwargs

    clearSession()

    model_inputs = all_inputs.get("modelInputs", None)
    call_inputs = all_inputs.get("callInputs", None)

    # Validate first. Previously `call_inputs.get(...)` (and the logging
    # below) ran before this check and raised AttributeError instead of
    # returning INVALID_INPUTS.
    if model_inputs is None or call_inputs is None:
        return {
            "$error": {
                "code": "INVALID_INPUTS",
                "message": "Expecting on object like { modelInputs: {}, callInputs: {} } but got "
                + json.dumps(all_inputs),
            }
        }

    print(json.dumps(truncateInputs(all_inputs), indent=2))

    result = {"$meta": {}}

    send_opts = {}
    if call_inputs.get("SEND_URL", None):
        send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")})
    if call_inputs.get("SIGN_KEY", None):
        send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")})
    if response:
        send_opts.update({"response": response})

        async def sendStatusAsync():
            await response.send(json.dumps(status.get()) + "\n")

        def sendStatus():
            # Best-effort heartbeat: re-arms itself every second and stops
            # silently once sending fails (e.g. the response was closed).
            try:
                asyncio.run(sendStatusAsync())
                Timer(1.0, sendStatus).start()
            except Exception:
                pass

        Timer(1.0, sendStatus).start()

    startRequestId = call_inputs.get("startRequestId", None)

    use_extra = call_inputs.get("use_extra", None)
    if use_extra:
        # Only public callables of the `extras` package may be selected.
        # (Previously `extras.keys()` - which modules don't have - raised
        # AttributeError, and any module attribute could be reached.)
        available_extras = [
            name
            for name in dir(extras)
            if not name.startswith("_") and callable(getattr(extras, name))
        ]
        extra = (
            getattr(extras, use_extra) if use_extra in available_extras else None
        )
        if not extra:
            return {
                "$error": {
                    "code": "NO_SUCH_EXTRA",
                    "message": 'Requested "'
                    + str(use_extra)
                    + '", available: "'
                    + '", "'.join(available_extras)
                    + '"',
                }
            }
        return await extra(
            model_inputs,
            call_inputs,
            send_opts=send_opts,
            startRequestId=startRequestId,
        )

    model_id = call_inputs.get("MODEL_ID", None)
    if not model_id:
        if not MODEL_ID:
            return {
                "$error": {
                    "code": "NO_MODEL_ID",
                    "message": "No callInputs.MODEL_ID specified, nor was MODEL_ID env var set.",
                }
            }
        model_id = MODEL_ID
        result["$meta"].update({"MODEL_ID": MODEL_ID})
    normalized_model_id = model_id

    if RUNTIME_DOWNLOADS:
        hf_model_id = call_inputs.get("HF_MODEL_ID", None)
        model_revision = call_inputs.get("MODEL_REVISION", None)
        model_precision = call_inputs.get("MODEL_PRECISION", None)
        checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
        checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
        normalized_model_id = normalize_model_id(model_id, model_revision)
        model_dir = os.path.join(MODELS_DIR, normalized_model_id)
        pipeline_name = call_inputs.get("PIPELINE", None)
        if pipeline_name:
            pipeline_class = getPipelineClass(pipeline_name)
        if last_model_id != normalized_model_id:
            # if not downloaded_models.get(normalized_model_id, None):
            if not os.path.isdir(model_dir):
                model_url = call_inputs.get("MODEL_URL", None)
                if not model_url:
                    # return {
                    #     "$error": {
                    #         "code": "NO_MODEL_URL",
                    #         "message": "Currently RUNTIME_DOWNOADS requires a MODEL_URL callInput",
                    #     }
                    # }
                    normalized_model_id = hf_model_id or model_id
                await download_model(
                    model_id=model_id,
                    model_url=model_url,
                    model_revision=model_revision,
                    checkpoint_url=checkpoint_url,
                    checkpoint_config_url=checkpoint_config_url,
                    hf_model_id=hf_model_id,
                    model_precision=model_precision,
                    send_opts=send_opts,
                    pipeline_class=pipeline_class if pipeline_name else None,
                )
                # downloaded_models.update({normalized_model_id: True})
            clearPipelines()
            cross_attention_kwargs = None
            if model:
                model.to("cpu")  # Necessary to avoid a memory leak
            await send(
                "loadModel", "start", {"startRequestId": startRequestId}, send_opts
            )
            model = await asyncio.to_thread(
                loadModel,
                model_id=normalized_model_id,
                load=True,
                precision=model_precision,
                revision=model_revision,
                send_opts=send_opts,
                pipeline_class=pipeline_class if pipeline_name else None,
            )
            await send(
                "loadModel", "done", {"startRequestId": startRequestId}, send_opts
            )
            last_model_id = normalized_model_id
            last_attn_procs = None
            last_lora_weights = None
    else:
        if always_normalize_model_id:
            normalized_model_id = always_normalize_model_id

    print(
        {
            "always_normalize_model_id": always_normalize_model_id,
            "normalized_model_id": normalized_model_id,
        }
    )

    if MODEL_ID == "ALL":
        if last_model_id != normalized_model_id:
            clearPipelines()
            cross_attention_kwargs = None
            model = loadModel(normalized_model_id, send_opts=send_opts)
            last_model_id = normalized_model_id
    else:
        if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
            return {
                "$error": {
                    "code": "MODEL_MISMATCH",
                    "message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
                    "requested": model_id,
                    "available": MODEL_ID,
                }
            }

    if PIPELINE == "ALL":
        pipeline_name = call_inputs.get("PIPELINE", None)
        if not pipeline_name:
            pipeline_name = "AutoPipelineForText2Image"
            result["$meta"].update({"PIPELINE": pipeline_name})

        pipeline = getPipelineForModel(
            pipeline_name,
            model,
            normalized_model_id,
            model_revision=model_revision if RUNTIME_DOWNLOADS else MODEL_REVISION,
            model_precision=model_precision if RUNTIME_DOWNLOADS else MODEL_PRECISION,
        )
        if not pipeline:
            return {
                "$error": {
                    "code": "NO_SUCH_PIPELINE",
                    "message": f'"{pipeline_name}" is not an official nor community Diffusers pipelines',
                    "requested": pipeline_name,
                    "available": listAvailablePipelines(),
                }
            }
    else:
        pipeline = model

    scheduler_name = call_inputs.get("SCHEDULER", None)
    if not scheduler_name:
        scheduler_name = "DPMSolverMultistepScheduler"
        result["$meta"].update({"SCHEDULER": scheduler_name})

    pipeline.scheduler = getScheduler(normalized_model_id, scheduler_name)
    if pipeline.scheduler is None:
        return {
            "$error": {
                "code": "INVALID_SCHEDULER",
                "message": f'"{scheduler_name}" is not a supported scheduler',
                # "requeted" (sic) is kept for backwards compatibility.
                "requeted": call_inputs.get("SCHEDULER", None),
                "requested": call_inputs.get("SCHEDULER", None),
                "available": ", ".join(SCHEDULERS),
            }
        }

    safety_checker = call_inputs.get("safety_checker", True)
    pipeline.safety_checker = (
        model.safety_checker
        if safety_checker and hasattr(model, "safety_checker")
        else None
    )
    is_url = call_inputs.get("is_url", False)
    image_decoder = getFromUrl if is_url else decodeBase64Image

    textual_inversions = call_inputs.get("textual_inversions", [])
    await handle_textual_inversions(textual_inversions, model, status=status)

    # Better to use new lora_weights in next section
    attn_procs = call_inputs.get("attn_procs", None)
    if attn_procs is not last_attn_procs:
        if attn_procs:
            raise Exception(
                "[REMOVED] Using `attn_procs` for LoRAs is no longer supported. "
                + "Please use `lora_weights` instead."
            )
        last_attn_procs = attn_procs
        # if attn_procs:
        #     storage = Storage(attn_procs, no_raise=True)
        #     if storage:
        #         hash = sha256(attn_procs.encode("utf-8")).hexdigest()
        #         attn_procs_from_safetensors = call_inputs.get(
        #             "attn_procs_from_safetensors", None
        #         )
        #         fname = storage.url.split("/").pop()
        #         if attn_procs_from_safetensors and not re.match(
        #             r".safetensors", attn_procs
        #         ):
        #             fname += ".safetensors"
        #         if True:
        #             # TODO, way to specify explicit name
        #             path = os.path.join(
        #                 MODELS_DIR, "attn_proc--url_" + hash[:7] + "--" + fname
        #             )
        #         attn_procs = path
        #         if not os.path.exists(path):
        #             storage.download_and_extract(path)
        #     print("Load attn_procs " + attn_procs)
        #     # Workaround https://github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119
        #     if storage and not re.search(r".safetensors", attn_procs):
        #         attn_procs = torch.load(attn_procs, map_location="cpu")
        #     pipeline.unet.load_attn_procs(attn_procs)
        # else:
        #     print("Clearing attn procs")
        #     pipeline.unet.set_attn_processor(CrossAttnProcessor())

    # Currently we only support a single string, but we should allow
    # and array too in anticipation of multi-LoRA support in diffusers
    # tracked at https://github.com/huggingface/diffusers/issues/2613.
    lora_weights = call_inputs.get("lora_weights", None)
    lora_weights_joined = json.dumps(lora_weights)
    if last_lora_weights != lora_weights_joined:
        if last_lora_weights != None and last_lora_weights != "[]":
            print("Unloading previous LoRA weights")
            pipeline.unload_lora_weights()

        last_lora_weights = lora_weights_joined
        cross_attention_kwargs = {}

        if type(lora_weights) is not list:
            lora_weights = [lora_weights] if lora_weights else []

        if len(lora_weights) > 0:
            for weights in lora_weights:
                storage = Storage(weights, no_raise=True, status=status)
                if storage:
                    storage_query_fname = storage.query.get("fname")
                    storage_query_scale = (
                        float(storage.query.get("scale")[0])
                        if storage.query.get("scale")
                        else 1
                    )
                    cross_attention_kwargs.update({"scale": storage_query_scale})
                    # https://github.com/damian0815/compel/issues/42#issuecomment-1656989385
                    pipeline._lora_scale = storage_query_scale
                    if storage_query_fname:
                        # The fname comes from the request URL; basename()
                        # keeps the cached file inside MODELS_DIR.
                        fname = os.path.basename(storage_query_fname[0])
                    else:
                        hash = sha256(weights.encode("utf-8")).hexdigest()
                        fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
                    cache_fname = "lora_weights--" + fname
                    path = os.path.join(MODELS_DIR, cache_fname)
                    if not os.path.exists(path):
                        await asyncio.to_thread(storage.download_file, path)
                    print("Load lora_weights `" + weights + "` from `" + path + "`")
                    pipeline.load_lora_weights(
                        MODELS_DIR, weight_name=cache_fname, local_files_only=True
                    )
                else:
                    print("Loading from huggingface not supported yet: " + weights)
                    # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
                    # lora_model_id = "sayakpaul/civitai-light-shadow-lora"
                    # lora_filename = "light_and_shadow.safetensors"
                    # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
    else:
        print("No changes to LoRAs since last call")

    # TODO, generalize
    # Work on a per-request copy. Model-input overrides used to be merged
    # into the module-level `cross_attention_kwargs`, which persists across
    # requests (it is only reset when the model or LoRAs change), so they
    # silently applied to later requests too.
    call_cross_attention_kwargs = (
        dict(cross_attention_kwargs) if cross_attention_kwargs else None
    )
    mi_cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
    if mi_cross_attention_kwargs:
        model_inputs.pop("cross_attention_kwargs")
        if isinstance(mi_cross_attention_kwargs, str):
            mi_cross_attention_kwargs = json.loads(mi_cross_attention_kwargs)
        elif not isinstance(mi_cross_attention_kwargs, dict):
            return {
                "$error": {
                    "code": "INVALID_CROSS_ATTENTION_KWARGS",
                    "message": "`cross_attention_kwargs` should be a dict or json string",
                }
            }
        if not call_cross_attention_kwargs:
            call_cross_attention_kwargs = {}
        call_cross_attention_kwargs.update(mi_cross_attention_kwargs)

    print({"cross_attention_kwargs": call_cross_attention_kwargs})
    if call_cross_attention_kwargs:
        model_inputs.update({"cross_attention_kwargs": call_cross_attention_kwargs})

    # Parse out your arguments
    # prompt = model_inputs.get("prompt", None)
    # if prompt == None:
    #     return {"message": "No prompt provided"}
    #
    #   height = model_inputs.get("height", 512)
    #  width = model_inputs.get("width", 512)
    # num_inference_steps = model_inputs.get("num_inference_steps", 50)
    # guidance_scale = model_inputs.get("guidance_scale", 7.5)
    # seed = model_inputs.get("seed", None)
    #   strength = model_inputs.get("strength", 0.75)

    if "init_image" in model_inputs:
        model_inputs["init_image"] = image_decoder(
            model_inputs.get("init_image"), "init_image"
        )

    if "image" in model_inputs:
        model_inputs["image"] = image_decoder(model_inputs.get("image"), "image")

    if "mask_image" in model_inputs:
        model_inputs["mask_image"] = image_decoder(
            model_inputs.get("mask_image"), "mask_image"
        )

    if "instance_images" in model_inputs:
        model_inputs["instance_images"] = list(
            map(
                lambda str: image_decoder(str, "instance_image"),
                model_inputs["instance_images"],
            )
        )

    await send("inference", "start", {"startRequestId": startRequestId}, send_opts)

    # Run patchmatch for inpainting
    if call_inputs.get("FILL_MODE", None) == "patchmatch":
        sel_buffer = np.array(model_inputs.get("init_image"))
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        img = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
        model_inputs["init_image"] = PIL.Image.fromarray(img)
        mask = 255 - mask
        mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
        mask = mask.repeat(8, axis=0).repeat(8, axis=1)
        model_inputs["mask_image"] = PIL.Image.fromarray(mask)

    # Turning on takes 3ms and turning off 1ms... don't worry, I've got your back :)
    # x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", True)
    # last_x_m_e_a = last_xformers_memory_efficient_attention.get(pipeline, None)
    # if x_m_e_a != last_x_m_e_a:
    #     if x_m_e_a == True:
    #         print("pipeline.enable_xformers_memory_efficient_attention()")
    #         pipeline.enable_xformers_memory_efficient_attention()  # default on
    #     elif x_m_e_a == False:
    #         print("pipeline.disable_xformers_memory_efficient_attention()")
    #         pipeline.disable_xformers_memory_efficient_attention()
    #     else:
    #         return {
    #             "$error": {
    #                 "code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
    #                 "message": f"x_m_e_a expects True or False, not: {x_m_e_a}",
    #                 "requested": x_m_e_a,
    #                 "available": [True, False],
    #             }
    #         }
    #     last_xformers_memory_efficient_attention.update({pipeline: x_m_e_a})

    # Run the model
    # with autocast(device_id):
    # image = pipeline(**model_inputs).images[0]

    if call_inputs.get("train", None) == "dreambooth":
        if not USE_DREAMBOOTH:
            return {
                "$error": {
                    "code": "TRAIN_DREAMBOOTH_NOT_AVAILABLE",
                    "message": 'Called with callInput { train: "dreambooth" } but built with USE_DREAMBOOTH=0',
                }
            }

        if RUNTIME_DOWNLOADS:
            if os.path.isdir(model_dir):
                normalized_model_id = model_dir

        torch.set_grad_enabled(True)
        try:
            train_result = await asyncio.to_thread(
                TrainDreamBooth,
                normalized_model_id,
                pipeline,
                model_inputs,
                call_inputs,
                send_opts=send_opts,
            )
        finally:
            # Restore inference-only mode even if training raised.
            torch.set_grad_enabled(False)
        result = result | train_result
        await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
        result.update({"$timings": getTimings()})
        return result

    # Do this after dreambooth as dreambooth accepts a seed int directly.
    # pop() also removes an explicit `"seed": null`, which previously stayed
    # in model_inputs and was passed on to the pipeline as an unknown kwarg.
    seed = model_inputs.pop("seed", None)
    generator = torch.Generator(device=device)
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(seed)
    model_inputs.update({"generator": generator})

    callback = None
    if model_inputs.get("callback_steps", None):

        def callback(step: int, timestep: int, latents: torch.FloatTensor):
            asyncio.run(
                send(
                    "inference",
                    "progress",
                    {"startRequestId": startRequestId, "step": step},
                    send_opts,
                )
            )

    else:
        vae = pipeline.vae
        # vae = tinyVae(vae)
        scaling_factor = vae.config.scaling_factor
        image_processor = pipeline.image_processor

        def callback(step: int, timestep: int, latents: torch.FloatTensor):
            status.update(
                "inference", step / model_inputs.get("num_inference_steps", 50)
            )

            # with torch.no_grad():
            #     image = vae.decode(latents / scaling_factor, return_dict=False)[0]
            # image = image_processor.postprocess(image, output_type="pil")[0]
            # image.save(f"step_{step}_img0.png")

    is_sdxl = (
        isinstance(model, StableDiffusionXLPipeline)
        or isinstance(model, StableDiffusionXLImg2ImgPipeline)
        or isinstance(model, StableDiffusionXLInpaintPipeline)
    )

    with torch.inference_mode():
        custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
        print(
            {
                "callback": callback,
                "**model_inputs": model_inputs,
            },
        )

        if call_inputs.get("compel_prompts", False):
            prepare_prompts(pipeline, model_inputs, is_sdxl)

        try:
            async_pipeline = asyncio.to_thread(
                getattr(pipeline, custom_pipeline_method)
                if custom_pipeline_method
                else pipeline,
                callback=callback,
                **model_inputs,
            )
            # if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
            #     # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
            #     # still broken in 0.5.1
            #     with autocast(device_id):
            #         images = (await async_pipeline).images
            # else:
            pipeResult = await async_pipeline
            images = pipeResult.images

        except Exception as err:
            return {
                "$error": {
                    "code": "PIPELINE_ERROR",
                    "name": type(err).__name__,
                    "message": str(err),
                    "stack": traceback.format_exc(),
                }
            }

    images_base64 = []
    image_format = call_inputs.get("image_format", "PNG")
    image_opts = (
        {"lossless": True} if image_format == "PNG" or image_format == "WEBP" else {}
    )
    for image in images:
        buffered = BytesIO()
        image.save(buffered, format=image_format, **image_opts)
        images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

    await send("inference", "done", {"startRequestId": startRequestId}, send_opts)

    # Return the results as a dictionary
    if len(images_base64) > 1:
        result = result | {"images_base64": images_base64}
    elif images_base64:
        result = result | {"image_base64": images_base64[0]}

    nsfw_content_detected = pipeResult.get("nsfw_content_detected", None)
    if nsfw_content_detected:
        result = result | {"nsfw_content_detected": nsfw_content_detected}

    # TODO, move and generalize in device.py
    mem_usage = 0
    if torch.cuda.is_available():
        max_memory_allocated = torch.cuda.max_memory_allocated()
        if max_memory_allocated:
            mem_usage = torch.cuda.memory_allocated() / max_memory_allocated

    result = result | {"$timings": getTimings(), "$mem_usage": mem_usage}
    return result
================================================
FILE: api/convert_to_diffusers.py
================================================
import os
import requests
import subprocess
import torch
import json
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from diffusers.pipelines.stable_diffusion import (
StableDiffusionInpaintPipeline,
)
from utils import Storage
from device import device_id
# Configuration read from the environment; used when run as a script.
MODEL_ID = os.environ.get("MODEL_ID", None)
# Directory holding downloaded checkpoints (and their config files).
CHECKPOINT_DIR = "/root/.cache/checkpoints"
CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None)
CHECKPOINT_CONFIG_URL = os.environ.get("CHECKPOINT_CONFIG_URL", None)
# Optional JSON object of extra kwargs that override the defaults passed to
# download_from_original_stable_diffusion_ckpt().
CHECKPOINT_ARGS = os.environ.get("CHECKPOINT_ARGS", None)
# _CONVERT_SPECIAL = os.environ.get("_CONVERT_SPECIAL", None)
def main(
    model_id: str,
    checkpoint_url: str,
    checkpoint_config_url: str,
    checkpoint_args: dict = None,
    path=None,
):
    """Convert an original Stable Diffusion checkpoint to a diffusers model.

    Args:
        model_id: directory the converted pipeline is saved to
            (``save_pretrained(model_id, safe_serialization=True)``).
        checkpoint_url: source URL. Only used to derive the local filename
            when `path` is not given; this function does not download the
            checkpoint itself.
        checkpoint_config_url: optional URL of the original YAML config.
            When set, it is downloaded into CHECKPOINT_DIR and passed as
            ``original_config_file``.
        checkpoint_args: optional overrides for the kwargs passed to
            ``download_from_original_stable_diffusion_ckpt``.
        path: local checkpoint path. Defaults to
            ``CHECKPOINT_DIR/<last URL segment>``.
    """
    checkpoint_args = checkpoint_args or {}

    if not path:
        fname = checkpoint_url.split("/").pop()
        path = os.path.join(CHECKPOINT_DIR, fname)

    configPath = None
    if checkpoint_config_url:
        storage = Storage(checkpoint_config_url)
        # `path` is already a full path, so the old
        # `CHECKPOINT_DIR + "/" + path` produced a doubled directory
        # ("/root/.cache/checkpoints//root/.cache/checkpoints/...").
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        configPath = os.path.join(
            CHECKPOINT_DIR, os.path.basename(path) + "_config.yaml"
        )
        print(f"Downloading {checkpoint_config_url} to {configPath}...")
        storage.download_file(configPath)

    print("Converting " + path + " to diffusers model " + model_id + "...", flush=True)

    # Diffusers provides the conversion as a library call now (changed on
    # 2023-02-18); the old approach ran the convert script as a subprocess.

    # diffusers defaults
    args = {
        "scheduler_type": "pndm",
    }

    # our defaults
    args.update(
        {
            "checkpoint_path_or_dict": path,
            "original_config_file": configPath,
            "device": device_id,
            "extract_ema": True,
            "from_safetensors": "safetensor" in path.lower(),
        }
    )

    if "inpaint" in path or "Inpaint" in path:
        args.update({"pipeline_class": StableDiffusionInpaintPipeline})

    # user overrides
    args.update(checkpoint_args)

    pipe = download_from_original_stable_diffusion_ckpt(**args)
    pipe.save_pretrained(model_id, safe_serialization=True)
if __name__ == "__main__":
    # Script entry point: convert the checkpoint named by the CHECKPOINT_URL
    # env var (if any), with optional JSON kwarg overrides from CHECKPOINT_ARGS.
    if CHECKPOINT_URL:
        main(
            MODEL_ID,
            CHECKPOINT_URL,
            CHECKPOINT_CONFIG_URL,
            checkpoint_args=json.loads(CHECKPOINT_ARGS) if CHECKPOINT_ARGS else {},
        )
================================================
FILE: api/device.py
================================================
import torch
def _detect_device():
    """Pick the best available torch backend: CUDA, then MPS, then CPU.

    Prints which backend was chosen (plus diagnostics when falling back to
    CPU) and returns a ``(device_id, device_name)`` tuple.
    """
    if torch.cuda.is_available():
        print("[device] CUDA (Nvidia) detected")
        return "cuda", torch.cuda.get_device_name()

    if torch.backends.mps.is_available():
        print("[device] MPS (MacOS Metal, Apple M1, etc) detected")
        return "mps", "MPS"

    print("[device] CPU only - no GPU detected")
    if not torch.backends.cuda.is_built():
        print(
            "CUDA not available because the current PyTorch install was not "
            "built with CUDA enabled."
        )
    if torch.backends.mps.is_built():
        print(
            "MPS not available because the current MacOS version is not 12.3+ "
            "and/or you do not have an MPS-enabled device on this machine."
        )
    else:
        print(
            "MPS not available because the current PyTorch install was not "
            "built with MPS enabled."
        )
    return "cpu", "CPU only"


device_id, device_name = _detect_device()
device = torch.device(device_id)
================================================
FILE: api/download.py
================================================
# In this file, we define download_model
# It runs during container build time to get model weights built into the container
import os
from loadModel import loadModel, MODEL_IDS
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from utils import Storage
import subprocess
from pathlib import Path
import shutil
from convert_to_diffusers import main as convert_to_diffusers
from download_checkpoint import main as download_checkpoint
from status import status
import asyncio
# Raw environment configuration (strings or None).
USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
# When set, progress events are forwarded by send() below.
RUNTIME_DOWNLOADS = os.environ.get("RUNTIME_DOWNLOADS")

HOME = os.path.expanduser("~")
# Local cache for downloaded / converted models; created at import time.
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
# i.e. don't run during build
async def send(
    type: str, status: str, payload: dict = None, send_opts: dict = None
):
    """Forward an event to `send.send`, but only when RUNTIME_DOWNLOADS is set
    (i.e. don't run during build).

    `payload` and `send_opts` default to fresh empty dicts on every call. The
    previous `{}` defaults were single dict objects shared by all calls and
    handed to another module, so any mutation there would have leaked
    between calls.
    """
    if RUNTIME_DOWNLOADS:
        from send import send as _send

        await _send(
            type,
            status,
            {} if payload is None else payload,
            {} if send_opts is None else send_opts,
        )
def normalize_model_id(model_id: str, model_revision):
    """Build a filesystem-safe cache name for a model.

    Slashes in `model_id` become "--", the result is prefixed with
    "models--", and a truthy `model_revision` is appended after another "--".
    E.g. ("org/name", "fp16") -> "models--org--name--fp16".
    """
    parts = ["models", model_id.replace("/", "--")]
    if model_revision:
        parts.append(model_revision)
    return "--".join(parts)
async def download_model(
    model_url=None,
    model_id=None,
    model_revision=None,
    checkpoint_url=None,
    checkpoint_config_url=None,
    hf_model_id=None,
    model_precision=None,
    send_opts=None,
    pipeline_class=None,
):
    """Make a model available locally.

    With `model_url`:
      * if the archive exists in storage, download and extract it into
        ``MODELS_DIR/<normalized id>``;
      * otherwise build it first, then save it with safetensors, compress it
        with ``tar | zstd``, upload it to `model_url`'s storage and remove the
        local archive. Building means converting `checkpoint_url` (if given)
        or downloading `hf_model_id` from Hugging Face, then loading it.

    Without `model_url`: convert `checkpoint_url` if given, else do a dry-run
    `loadModel(load=False)` so the Hugging Face weights get downloaded.

    `send_opts` defaults to a fresh empty dict (it was a shared `{}` default).
    """
    # Local import: only needed to shell-quote paths below.
    import shlex

    if send_opts is None:
        send_opts = {}

    print(
        "download_model",
        {
            "model_url": model_url,
            "model_id": model_id,
            "model_revision": model_revision,
            "hf_model_id": hf_model_id,
            "checkpoint_url": checkpoint_url,
            "checkpoint_config_url": checkpoint_config_url,
        },
    )
    hf_model_id = hf_model_id or model_id
    normalized_model_id = model_id

    # if model_url != "": # throws an error, useful to debug stdout/stderr order
    if model_url:
        normalized_model_id = normalize_model_id(model_id, model_revision)
        print({"normalized_model_id": normalized_model_id})
        filename = model_url.split("/").pop()
        if not filename:
            filename = normalized_model_id + ".tar.zst"
        model_file = os.path.join(MODELS_DIR, filename)
        storage = Storage(
            model_url, default_path=normalized_model_id + ".tar.zst", status=status
        )
        exists = storage.file_exists()
        if exists:
            model_dir = os.path.join(MODELS_DIR, normalized_model_id)
            print("model_dir", model_dir)
            await asyncio.to_thread(storage.download_and_extract, model_file, model_dir)
        else:
            if checkpoint_url:
                path = download_checkpoint(checkpoint_url)
                convert_to_diffusers(
                    model_id=model_id,
                    checkpoint_url=checkpoint_url,
                    checkpoint_config_url=checkpoint_config_url,
                    path=path,
                )
            else:
                print("Does not exist, let's try find it on huggingface")
                print(
                    {
                        "model_precision": model_precision,
                        "model_revision": model_revision,
                    }
                )
                # This would be quicker to just model.to(device) afterwards, but
                # this conveniently logs all the timings (and doesn't happen often)
                print("download")
                await send("download", "start", {}, send_opts)
                loadModel(
                    hf_model_id,
                    False,
                    precision=model_precision,
                    revision=model_revision,
                    pipeline_class=pipeline_class,
                )  # download
                await send("download", "done", {}, send_opts)

            print("load")
            model = loadModel(
                hf_model_id,
                True,
                precision=model_precision,
                revision=model_revision,
                pipeline_class=pipeline_class,
            )  # load
            # dir = "models--" + model_id.replace("/", "--") + "--dda"
            save_dir = os.path.join(MODELS_DIR, normalized_model_id)
            model.save_pretrained(save_dir, safe_serialization=True)

            # This is all duped from train_dreambooth, need to refactor TODO XXX
            await send("compress", "start", {}, send_opts)
            # model_id / model_url can come from request callInputs, so quote
            # them before interpolating into a shell command.
            subprocess.run(
                f"tar cvf - -C {shlex.quote(save_dir)} . "
                f"| zstd -o {shlex.quote(model_file)}",
                shell=True,
                check=True,  # TODO, rather don't raise and return an error in JSON
            )

            await send("compress", "done", {}, send_opts)
            subprocess.run(["ls", "-l", model_file])

            await send("upload", "start", {}, send_opts)
            upload_result = storage.upload_file(model_file, filename)
            await send("upload", "done", {}, send_opts)
            print(upload_result)
            os.remove(model_file)

            # leave model dir for future loads... make configurable?
            # shutil.rmtree(dir)

            # TODO, swap directories, inside HF's cache structure.

    else:
        if checkpoint_url:
            path = download_checkpoint(checkpoint_url)
            convert_to_diffusers(
                model_id=model_id,
                checkpoint_url=checkpoint_url,
                checkpoint_config_url=checkpoint_config_url,
                path=path,
            )
        else:
            # do a dry run of loading the huggingface model, which will download weights at build time
            loadModel(
                model_id=hf_model_id,
                load=False,
                precision=model_precision,
                revision=model_revision,
                pipeline_class=pipeline_class,
            )
if __name__ == "__main__":
    # Script entry point: fetch / convert the model described by env vars.
    env = os.environ.get
    asyncio.run(
        download_model(
            model_url=env("MODEL_URL"),
            model_id=env("MODEL_ID"),
            hf_model_id=env("HF_MODEL_ID"),
            model_revision=env("MODEL_REVISION"),
            model_precision=env("MODEL_PRECISION"),
            checkpoint_url=env("CHECKPOINT_URL"),
            checkpoint_config_url=env("CHECKPOINT_CONFIG_URL"),
        )
    )
================================================
FILE: api/download_checkpoint.py
================================================
import os
from utils import Storage
# Checkpoint to fetch when run as a script.
CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None)
# Local directory where checkpoints are stored.
CHECKPOINT_DIR = "/root/.cache/checkpoints"
def main(checkpoint_url: str):
    """Ensure `checkpoint_url` is downloaded into CHECKPOINT_DIR.

    The local filename is the URL's ``fname`` query parameter when present,
    otherwise the last "/"-separated segment of the URL. The file is only
    downloaded if it does not already exist.

    Returns:
        The local path of the checkpoint.
    """
    # exist_ok avoids the check-then-create race between concurrent callers.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    storage = Storage(checkpoint_url)
    storage_query_fname = storage.query.get("fname")
    # The URL may come from a request. basename() stops an absolute or
    # "../" fname from escaping CHECKPOINT_DIR (os.path.join discards the
    # base directory for absolute paths).
    fname = os.path.basename(storage_query_fname[0]) if storage_query_fname else ""
    if not fname:
        fname = checkpoint_url.split("/").pop()

    path = os.path.join(CHECKPOINT_DIR, fname)
    if not os.path.isfile(path):
        storage.download_file(path)
    return path
# Script entry point: download the checkpoint named by the CHECKPOINT_URL
# env var, if set.
if __name__ == "__main__":
    if CHECKPOINT_URL:
        main(CHECKPOINT_URL)
================================================
FILE: api/extras/__init__.py
================================================
from .upsample import upsample
================================================
FILE: api/extras/upsample/__init__.py
================================================
from .upsample import upsample
================================================
FILE: api/extras/upsample/models.py
================================================
# Super-resolution models this container can serve, keyed by model id (the
# MODEL_ID a caller passes to upsample()). Fields:
#   name     - human-readable label
#   weights  - URL the weights file is downloaded from
#   filename - name of the weights file inside the upsample cache directory
#   net      - network class name, looked up in upsample.nets
#   initArgs - keyword arguments for that network's constructor
#   netscale - passed to RealESRGANer as its `scale`
upsamplers = {
    "RealESRGAN_x4plus": {
        "name": "General - RealESRGANplus",
        "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
        "filename": "RealESRGAN_x4plus.pth",
        "net": "RRDBNet",
        "initArgs": {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_block": 23,
            "num_grow_ch": 32,
            "scale": 4,
        },
        "netscale": 4,
    },
    # "RealESRNet_x4plus": {
    # "name": "",
    # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    # "path": "weights/RealESRNet_x4plus.pth",
    # },
    "RealESRGAN_x4plus_anime_6B": {
        "name": "Anime - anime6B",
        "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
        "filename": "RealESRGAN_x4plus_anime_6B.pth",
        "net": "RRDBNet",
        "initArgs": {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_block": 6,
            "num_grow_ch": 32,
            "scale": 4,
        },
        "netscale": 4,
    },
    # "RealESRGAN_x2plus": {
    # "name": "",
    # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    # "path": "weights/RealESRGAN_x2plus.pth",
    # },
    # "realesr-animevideov3": {
    # "name": "AnimeVideo - v3",
    # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    # "path": "weights/realesr-animevideov3.pth",
    # },
    "realesr-general-x4v3": {
        "name": "General - v3",
        # NOTE(review): the "wdn" weights below appear intended for the
        # (not yet implemented) denoise_strength option in upsample().
        # [, "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" ],
        "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
        "filename": "realesr-general-x4v3.pth",
        "net": "SRVGGNetCompact",
        "initArgs": {
            "num_in_ch": 3,
            "num_out_ch": 3,
            "num_feat": 64,
            "num_conv": 32,
            "upscale": 4,
            "act_type": "prelu",
        },
        "netscale": 4,
    },
}
# Face enhancement models; upsample() loads "GFPGAN" when face_enhance is set.
face_enhancers = {
    "GFPGAN": {
        "name": "GFPGAN",
        "weights": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
        "filename": "GFPGANv1.4.pth",
    },
}
# All model tables by type; download_models() fetches every entry's weights.
models_by_type = {
    "upsamplers": upsamplers,
    "face_enhancers": face_enhancers,
}
================================================
FILE: api/extras/upsample/upsample.py
================================================
import os
import asyncio
from pathlib import Path
import base64
from io import BytesIO
import PIL
import json
import cv2
import numpy as np
import torch
import torchvision
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
from .models import models_by_type, upsamplers, face_enhancers
from status import status
from utils import Storage
from send import send
# Log the torch / torchvision versions at import time.
print(
    {
        "torch.__version__": torch.__version__,
        "torchvision.__version__": torchvision.__version__,
    }
)
HOME = os.path.expanduser("~")
# Directory that upsampler / face-enhancer weights are cached in.
CACHE_DIR = os.path.join(HOME, ".cache", "diffusers-api", "upsample")
def cache_path(filename):
    """Return ``filename`` resolved inside CACHE_DIR."""
    cache_root = CACHE_DIR
    return os.path.join(cache_root, filename)
async def assert_model_exists(src, filename, send_opts, opts={}):
    """
    Ensure a model file exists locally, downloading it from ``src`` if not.

    ``filename`` is resolved inside CACHE_DIR unless ``opts["absolutePath"]``
    is truthy, in which case it is used as given. "download" start/done
    events are sent only when a download actually happens.
    """
    use_absolute = opts.get("absolutePath", None)
    dest = filename if use_absolute else cache_path(filename)
    if os.path.exists(dest):
        return
    await send("download", "start", {}, send_opts)
    storage = Storage(src, status=status)
    # download_file() is synchronous; run it in a worker thread so the
    # event loop isn't blocked during the download.
    await asyncio.to_thread(storage.download_file, dest)
    await send("download", "done", {}, send_opts)
async def download_models(send_opts={}):
    """
    Make sure every registered upsampler / face-enhancer weight file, plus
    the facexlib helper weights that GFPGAN needs, is present in CACHE_DIR.

    facexlib loads its helper weights from the hardcoded relative directory
    ``gfpgan/weights``. Those files are therefore symlinked there from the
    cache.

    Args:
        send_opts: Forwarded to ``send`` for the "download" events.
    """
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

    for registry in models_by_type.values():
        for entry in registry.values():
            await assert_model_exists(entry["weights"], entry["filename"], send_opts)

    # hardcoded paths in xinntao/facexlib
    facexlib_weights = {
        "detection_Resnet50_Final.pth": "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
        "parsing_parsenet.pth": "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
    }
    Path("gfpgan/weights").mkdir(parents=True, exist_ok=True)
    for filename, url in facexlib_weights.items():
        await assert_model_exists(url, filename, send_opts)

    for filename in facexlib_weights:
        link = f"gfpgan/weights/{filename}"
        if os.path.islink(link) and not os.path.exists(link):
            # Dangling link: exists() reports it missing but symlink() would
            # fail with FileExistsError, so replace it.
            os.unlink(link)
        if not os.path.lexists(link):
            os.symlink(cache_path(filename), link)
# Network constructors, keyed by the "net" field of a models.py entry.
nets = {
    "RRDBNet": RRDBNet,
    "SRVGGNetCompact": SRVGGNetCompact,
}
# Cache shared across requests: loaded upsampler entries keyed by model id,
# plus the GFPGAN face enhancer under the key "GFPGAN".
models = {}
async def _load_upsampler(model_id, model, send_opts, startRequestId):
    """Build (and cache) the RealESRGANer for registry entry ``model``."""
    modelModel = nets[model["net"]](**model["initArgs"])
    await send("loadModel", "start", {"startRequestId": startRequestId}, send_opts)
    upsampler = RealESRGANer(
        scale=model["netscale"],
        model_path=cache_path(model["filename"]),
        dni_weight=None,
        model=modelModel,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
    )
    await send("loadModel", "done", {"startRequestId": startRequestId}, send_opts)
    # Note: this also stores the loaded objects on the shared registry entry.
    model.update({"model": modelModel, "upsampler": upsampler})
    models.update({model_id: model})
    return model


async def _get_face_enhancer(upsampler, send_opts, startRequestId):
    """Return the cached GFPGAN enhancer (loading it on first use) bound to ``upsampler``."""
    face_enhancer = models.get("GFPGAN", None)
    if face_enhancer is None:
        await send("loadModel", "start", {"startRequestId": startRequestId}, send_opts)
        print("1) " + cache_path(face_enhancers["GFPGAN"]["filename"]))
        face_enhancer = GFPGANer(
            model_path=cache_path(face_enhancers["GFPGAN"]["filename"]),
            upscale=4,  # args.outscale,
            arch="clean",
            channel_multiplier=2,
            bg_upsampler=upsampler,
        )
        await send("loadModel", "done", {"startRequestId": startRequestId}, send_opts)
        models.update({"GFPGAN": face_enhancer})
    # The cached enhancer may have been created with another upsampler.
    face_enhancer.bg_upsampler = upsampler
    return face_enhancer


async def upsample(model_inputs, call_inputs, send_opts={}, startRequestId=None):
    """
    Upscale a base64-encoded image, optionally with GFPGAN face enhancement.

    Args:
        model_inputs: ``input_image`` (base64 image bytes, required),
            ``face_enhance`` (bool) and ``denoise_strength`` (only for
            "realesr-general-x4v3"; values other than 1 are not implemented).
        call_inputs: ``MODEL_ID``, a key of the upsamplers registry.
        send_opts: Forwarded to ``send`` for progress events.
        startRequestId: Included in "loadModel" / "inference" event payloads.

    Returns:
        ``{"$meta": {}, "image_base64": <jpeg, base64>}`` on success, or a
        ``{"$error": {...}}`` dict for invalid input.
    """
    # TODO, only download relevant models for this request
    await download_models(send_opts)

    model_id = call_inputs.get("MODEL_ID", None)
    if not model_id:
        return {
            "$error": {
                "code": "MISSING_MODEL_ID",
                "message": "call_inputs.MODEL_ID is required, but not given.",
            }
        }

    # Validate against the registry (not the load cache, which also holds
    # the "GFPGAN" face enhancer and only contains models already loaded).
    available = models_by_type["upsamplers"]
    if model_id not in available:
        return {
            "$error": {
                "code": "MISSING_MODEL",
                "message": f'Model "{model_id}" not available on this container.',
                "requested": model_id,
                "available": '"' + '", "'.join(available.keys()) + '"',
            }
        }

    model = models.get(model_id, None)
    if model is None:
        model = await _load_upsampler(
            model_id, available[model_id], send_opts, startRequestId
        )
    upsampler = model["upsampler"]

    input_image = model_inputs.get("input_image", None)
    if not input_image:
        return {
            "$error": {
                "code": "NO_INPUT_IMAGE",
                "message": "Missing required parameter `input_image`",
            }
        }

    if model_id == "realesr-general-x4v3":
        denoise_strength = model_inputs.get("denoise_strength", 1)
        if denoise_strength != 1:
            # wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
            # model_path = [model_path, wdn_model_path]
            # upsampler = models["realesr-general-x4v3-denoise"]
            # dni_weight = [denoise_strength, 1 - denoise_strength]
            # upsampler.dni_weight = dni_weight
            return "TODO: denoise_strength"

    face_enhance = model_inputs.get("face_enhance", False)
    face_enhancer = None
    if face_enhance:
        face_enhancer = await _get_face_enhancer(upsampler, send_opts, startRequestId)

    image_np = np.frombuffer(base64.b64decode(input_image), dtype=np.uint8)
    img = cv2.imdecode(image_np, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imdecode returns None instead of raising for undecodable data.
        return {
            "$error": {
                "code": "INVALID_INPUT_IMAGE",
                "message": "Could not decode `input_image` as an image",
            }
        }

    await send("inference", "start", {"startRequestId": startRequestId}, send_opts)
    if face_enhance:
        _, _, output = face_enhancer.enhance(
            img, has_aligned=False, only_center_face=False, paste_back=True
        )
    else:
        output, _rgb = upsampler.enhance(img, outscale=4)  # TODO outscale param
    image_base64 = base64.b64encode(cv2.imencode(".jpg", output)[1]).decode()
    await send("inference", "done", {"startRequestId": startRequestId}, send_opts)

    # Return the results as a dictionary
    return {"$meta": {}, "image_base64": image_base64}
================================================
FILE: api/getPipeline.py
================================================
import time
import os, fnmatch
from diffusers import (
DiffusionPipeline,
pipelines as diffusers_pipelines,
)
from precision import torch_dtype_from_precision
HOME = os.path.expanduser("~")
# Local model cache root; a model's files are looked up in MODELS_DIR/<model_id>.
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
# Pipelines built for the currently loaded model, keyed by pipeline name.
_pipelines = {}
# Lazily computed list of community pipeline names (see
# availableCommunityPipelines()).
_availableCommunityPipelines = None
def listAvailablePipelines():
    """
    Names of every pipeline that can be requested: the ``*Pipeline``
    attributes exported by ``diffusers.pipelines``, followed by the
    community pipelines found on disk.
    """
    builtin_names = [
        name
        for name in list(diffusers_pipelines.__dict__.keys())
        if name.endswith("Pipeline")
    ]
    return builtin_names + availableCommunityPipelines()
def availableCommunityPipelines():
    """
    List community pipeline names: the ``*.py`` files in
    ``diffusers/examples/community`` (relative to the current working
    directory), without their extension.

    The directory is scanned once and the result cached at module level.
    """
    global _availableCommunityPipelines
    # Compare with None (not truthiness) so an empty result is cached too,
    # instead of re-scanning the directory on every call.
    if _availableCommunityPipelines is None:
        _availableCommunityPipelines = [
            filename[:-3]
            for filename in fnmatch.filter(
                os.listdir("diffusers/examples/community"), "*.py"
            )
        ]
    return _availableCommunityPipelines
def clearPipelines():
    """
    Drop every cached pipeline.

    Call this whenever the loaded model changes. Cached pipelines hold
    references to the previous model's components, which would otherwise
    keep that model's memory from being reclaimed.
    """
    global _pipelines
    empty_cache = {}
    _pipelines = empty_cache
def getPipelineClass(pipeline_name: str):
    """
    Resolve ``pipeline_name`` to a pipeline class.

    Returns the matching ``diffusers.pipelines`` attribute if one exists,
    ``DiffusionPipeline`` for a community pipeline, and ``None`` otherwise.
    """
    if not hasattr(diffusers_pipelines, pipeline_name):
        if pipeline_name in availableCommunityPipelines():
            return DiffusionPipeline
        return None
    return getattr(diffusers_pipelines, pipeline_name)
def getPipelineForModel(
    pipeline_name: str, model, model_id, model_revision, model_precision
):
    """
    Return a ``pipeline_name`` pipeline that re-uses the components of the
    already loaded ``model``.

    The result is cached by ``pipeline_name`` only. A cached instance is
    therefore returned regardless of the other arguments. Call
    ``clearPipelines()`` when loading a new model so the stale pipeline, and
    the previous model it references, can be garbage collected.

    Built-in diffusers pipelines are built in order of preference:
    1. via ``from_pipe`` when the class has it;
    2. otherwise from ``model.components``;
    3. otherwise from the individual SD components.

    Community pipelines are loaded with ``DiffusionPipeline.from_pretrained``
    (local files only), passing the model's components through.

    Returns:
        The pipeline, or ``None`` if ``pipeline_name`` is neither a built-in
        nor a community pipeline.
    """
    pipeline = _pipelines.get(pipeline_name)
    if pipeline:
        return pipeline

    start = time.time()
    if hasattr(diffusers_pipelines, pipeline_name):
        pipeline_class = getattr(diffusers_pipelines, pipeline_name)
        if hasattr(pipeline_class, "from_pipe"):
            pipeline = pipeline_class.from_pipe(model)
        elif hasattr(model, "components"):
            pipeline = pipeline_class(**model.components)
        else:
            pipeline = pipeline_class(
                vae=model.vae,
                text_encoder=model.text_encoder,
                tokenizer=model.tokenizer,
                unet=model.unet,
                scheduler=model.scheduler,
                safety_checker=model.safety_checker,
                feature_extractor=model.feature_extractor,
            )
    elif pipeline_name in availableCommunityPipelines():
        model_dir = os.path.join(MODELS_DIR, model_id)
        if not os.path.isdir(model_dir):
            model_dir = None
        pipeline = DiffusionPipeline.from_pretrained(
            model_dir or model_id,
            revision=model_revision,
            torch_dtype=torch_dtype_from_precision(model_precision),
            custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
            local_files_only=True,
            **model.components,
        )

    if not pipeline:
        # Previously this still logged "Initialized ..." for unknown names.
        print(f'No pipeline named "{pipeline_name}" available for {model_id}')
        return pipeline

    _pipelines.update({pipeline_name: pipeline})
    diff = round((time.time() - start) * 1000)
    print(f"Initialized {pipeline_name} for {model_id} in {diff}ms")
    return pipeline
================================================
FILE: api/getScheduler.py
================================================
import torch
import os
import time
from diffusers import schedulers as _schedulers
# Hugging Face token forwarded to from_pretrained() (None when unset).
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
HOME = os.path.expanduser("~")
# Local model cache root; a local copy lives in MODELS_DIR/<model_id>.
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
# Scheduler class names from diffusers.schedulers. In this module only the
# first entry is used, as the fallback for DEFAULT_SCHEDULER.
SCHEDULERS = [
    "DPMSolverMultistepScheduler",
    "LMSDiscreteScheduler",
    "DDIMScheduler",
    "PNDMScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
]
DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER", SCHEDULERS[0])
# The string literal below is disabled code kept for reference only; it is
# never executed.
"""
# This was a nice idea but until we have default init vars for all schedulers
# via from_pretrained(), it's a no go. In any case, loading a scheduler takes time
# so better to init as needed and cache.
isScheduler = re.compile(r".+Scheduler$")
for key, val in _schedulers.__dict__.items():
if isScheduler.match(key):
schedulers.update(
{
key: val.from_pretrained(
MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
)
}
)
"""
def initScheduler(MODEL_ID: str, scheduler_id: str, download=False):
    """
    Instantiate ``diffusers.schedulers.<scheduler_id>`` from the
    ``scheduler`` subfolder of MODEL_ID.

    The local copy in MODELS_DIR/<MODEL_ID> is used when present; otherwise
    MODEL_ID is passed to ``from_pretrained`` as-is.

    Args:
        MODEL_ID: Model whose scheduler config is loaded.
        scheduler_id: Class name in ``diffusers.schedulers``.
        download: Allow fetching files remotely; when False, only local
            files are used.

    Returns:
        The scheduler instance, or ``None`` if ``diffusers.schedulers`` has
        no attribute called ``scheduler_id``.
    """
    print(f"Initializing {scheduler_id} for {MODEL_ID}...")
    start = time.time()
    # A default is required for the None branch to be reachable: a bare
    # getattr() raises AttributeError for unknown names.
    scheduler = getattr(_schedulers, scheduler_id, None)
    if scheduler is None:
        print(f'Unknown scheduler "{scheduler_id}"')
        return None

    model_dir = os.path.join(MODELS_DIR, MODEL_ID)
    if not os.path.isdir(model_dir):
        model_dir = None

    inittedScheduler = scheduler.from_pretrained(
        model_dir or MODEL_ID,
        subfolder="scheduler",
        use_auth_token=HF_AUTH_TOKEN,
        local_files_only=not download,
    )
    diff = round((time.time() - start) * 1000)
    print(f"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms")
    return inittedScheduler
# Cache of initialized schedulers: {MODEL_ID: {scheduler_id: scheduler}}.
schedulers = {}


def getScheduler(MODEL_ID: str, scheduler_id: str, download=False):
    """
    Return the (cached) ``scheduler_id`` scheduler for MODEL_ID.

    The deprecated short names "LMS", "DDIM" and "PNDM" are still accepted.
    They are mapped to their full class names, with a warning.

    Args:
        MODEL_ID: Model the scheduler config comes from.
        scheduler_id: Scheduler class name (or deprecated short name).
        download: Passed to initScheduler() on a cache miss.

    Returns:
        The scheduler; a None result from initScheduler() is not reused,
        so a later call retries initialization.
    """
    schedulersByModel = schedulers.setdefault(MODEL_ID, {})

    # Check for use of old names
    deprecated_map = {
        "LMS": "LMSDiscreteScheduler",
        "DDIM": "DDIMScheduler",
        "PNDM": "PNDMScheduler",
    }
    scheduler_renamed = deprecated_map.get(scheduler_id, None)
    if scheduler_renamed is not None:
        print(
            f'[Deprecation Warning]: Scheduler "{scheduler_id}" is now '
            f'called "{scheduler_renamed}". Please rename as this will '
            f"stop working in a future release."
        )
        scheduler_id = scheduler_renamed

    scheduler = schedulersByModel.get(scheduler_id, None)
    if scheduler is None:
        scheduler = initScheduler(MODEL_ID, scheduler_id, download)
        schedulersByModel.update({scheduler_id: scheduler})
    return scheduler
================================================
FILE: api/lib/__init__.py
================================================
================================================
FILE: api/lib/prompts.py
================================================
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
def prepare_prompts(pipeline, model_inputs, is_sdxl):
    """
    Replace ``prompt`` / ``negative_prompt`` in ``model_inputs`` (in place)
    with Compel-built embeddings.

    Compel is used with ``truncate_long_prompts=False``. The positive and
    negative conditioning tensors are then padded to the same length, since
    diffusers expects ``prompt_embeds`` and ``negative_prompt_embeds`` to
    have matching shapes.

    Args:
        pipeline: Loaded diffusers pipeline. For SDXL it must also expose
            ``tokenizer_2`` / ``text_encoder_2``.
        model_inputs: Inference kwargs; mutated in place. ``prompt`` and
            ``negative_prompt`` are set to None and the ``*_embeds`` keys
            are added (plus the pooled embeddings for SDXL).
        is_sdxl: Whether ``pipeline`` is an SDXL pipeline.
    """
    if is_sdxl:
        compel = Compel(
            tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
            text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
            # diffusers has no ti in sdxl yet
            # https://github.com/huggingface/diffusers/issues/4376#issuecomment-1659016141
            # textual_inversion_manager=textual_inversion_manager,
            truncate_long_prompts=False,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        conditioning, pooled = compel(model_inputs.get("prompt"))
        negative_conditioning, negative_pooled = compel(
            model_inputs.get("negative_prompt")
        )
        pooled_embeds = {
            "pooled_prompt_embeds": pooled,
            "negative_pooled_prompt_embeds": negative_pooled,
        }
    else:
        compel = Compel(
            tokenizer=pipeline.tokenizer,
            text_encoder=pipeline.text_encoder,
            # Only this branch uses textual inversions, so the manager is
            # created here rather than for every call.
            textual_inversion_manager=DiffusersTextualInversionManager(pipeline),
            truncate_long_prompts=False,
        )
        conditioning = compel(model_inputs.get("prompt"))
        negative_conditioning = compel(model_inputs.get("negative_prompt"))
        pooled_embeds = {}

    [
        conditioning,
        negative_conditioning,
    ] = compel.pad_conditioning_tensors_to_same_length(
        [conditioning, negative_conditioning]
    )
    model_inputs.update(
        {
            "prompt": None,
            "negative_prompt": None,
            "prompt_embeds": conditioning,
            "negative_prompt_embeds": negative_conditioning,
            **pooled_embeds,
        }
    )
================================================
FILE: api/lib/textual_inversions.py
================================================
import json
import re
import os
import asyncio
from utils import Storage
from .vars import MODELS_DIR
# State kept between handle_textual_inversions() calls: the JSON of the
# last requested list, the model it was applied to, and what has been
# loaded into that model.
last_textual_inversions = None
last_textual_inversion_model = None
loaded_textual_inversion_tokens = []

# Matches a trailing "#fname=<name>.pt" / "&fname=<name>.safetensors",
# optionally followed by "&token=<token>".
tokenRe = re.compile(
    r"[#&]{1}fname=(?P<fname>[^\.]+)\.(?:pt|safetensors)(&token=(?P<token>[^&]+))?$"
)


def strMap(str: str):
    """
    Extract the token name from a textual-inversion URL.

    Returns the ``token`` value if the URL has one, else the ``fname``
    without its extension. Returns None when the URL doesn't match tokenRe.
    """
    found = tokenRe.search(str)
    if not found:
        return None
    explicit_token = found.group("token")
    return explicit_token if explicit_token else found.group("fname")


def extract_tokens_from_list(textual_inversions: list):
    """Apply strMap() to every URL (None for URLs that don't match)."""
    return [strMap(url) for url in textual_inversions]
async def handle_textual_inversions(textual_inversions: list, model, status):
    """
    Load the requested textual inversions into ``model``.

    Nothing is done if both the list (compared as JSON) and the model are
    unchanged since the previous call. Each entry is one of:
    - a URL understood by ``utils.Storage``: downloaded into MODELS_DIR
      once, then loaded from disk;
    - anything else: passed straight to ``model.load_textual_inversion``
      (e.g. a hub id).

    Entries already loaded into the current model are skipped, so an
    embedding is never registered twice.

    Args:
        textual_inversions: List of textual-inversion sources.
        model: The loaded pipeline.
        status: Status object forwarded to Storage for download progress.
    """
    global last_textual_inversions
    global last_textual_inversion_model
    global loaded_textual_inversion_tokens

    textual_inversions_str = json.dumps(textual_inversions)
    if (
        textual_inversions_str == last_textual_inversions
        and model is last_textual_inversion_model
    ):
        print("No changes to textual inversions since last call")
        return

    if model is not last_textual_inversion_model:
        # A newly loaded model has none of our embeddings yet.
        loaded_textual_inversion_tokens = []
        last_textual_inversion_model = model

    # NOTE: tokens that are no longer requested are not removed from the
    # tokenizer / text encoder (an earlier attempt at that was abandoned).
    last_textual_inversions = textual_inversions_str

    # loaded_textual_inversion_tokens holds one key per loaded embedding:
    # its explicit token, else its local path, else the raw source string.
    for textual_inversion in textual_inversions:
        storage = Storage(textual_inversion, no_raise=True, status=status)
        if storage:
            # Query values are lists; take the first one.
            storage_query_fname = storage.query.get("fname")
            if storage_query_fname:
                fname = storage_query_fname[0]
            else:
                fname = textual_inversion.split("/").pop()
            # fname comes from the request: basename() keeps "../" sequences
            # from escaping MODELS_DIR.
            path = os.path.join(
                MODELS_DIR, "textual_inversion--" + os.path.basename(fname)
            )
            if not os.path.exists(path):
                await asyncio.to_thread(storage.download_file, path)

            token_values = storage.query.get("token", None)
            token = token_values[0] if token_values else None
            # Key token-less embeddings by path; keying them all by None
            # made every one after the first get skipped.
            loaded_key = token or path
            if loaded_key in loaded_textual_inversion_tokens:
                continue
            print("Load textual inversion " + path)
            model.load_textual_inversion(path, token=token, local_files_only=True)
            loaded_textual_inversion_tokens.append(loaded_key)
        else:
            if textual_inversion in loaded_textual_inversion_tokens:
                continue
            print("Load textual inversion " + textual_inversion)
            model.load_textual_inversion(textual_inversion)
            loaded_textual_inversion_tokens.append(textual_inversion)
================================================
FILE: api/lib/textual_inversions_test.py
================================================
import unittest
from .textual_inversions import extract_tokens_from_list
class TextualInversionsTest(unittest.TestCase):
    """Unit tests for extract_tokens_from_list()."""

    def test_extract_tokens_query_fname(self):
        """With only ``#fname=``, the token is the file name minus extension."""
        first_token = extract_tokens_from_list(
            ["https://civitai.com/api/download/models/106132#fname=4nj0lie.pt"]
        )[0]
        self.assertEqual(first_token, "4nj0lie")

    def test_extract_tokens_query_token(self):
        """An explicit ``&token=`` is returned as the token."""
        first_token = extract_tokens_from_list(
            [
                "https://civitai.com/api/download/models/106132#fname=4nj0lie.pt&token=4nj0lie"
            ]
        )[0]
        self.assertEqual(first_token, "4nj0lie")
================================================
FILE: api/lib/vars.py
================================================
import os
# Feature flags: True only when the environment variable is exactly "1".
RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
# Container configuration (None when unset).
MODEL_ID = os.environ.get("MODEL_ID")
PIPELINE = os.environ.get("PIPELINE")
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
HOME = os.path.expanduser("~")
# Local cache root; downloaded textual inversions are stored here as well.
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
================================================
FILE: api/loadModel.py
================================================
import torch
import os
from diffusers import pipelines as _pipelines, AutoPipelineForText2Image
from getScheduler import getScheduler, DEFAULT_SCHEDULER
from precision import torch_dtype_from_precision
from device import device
import time
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
# Pipeline class name from diffusers.pipelines, or "ALL" to use the class
# requested per call (see loadModel()).
PIPELINE = os.getenv("PIPELINE")
USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
HOME = os.path.expanduser("~")
# Local model cache root; a local copy lives in MODELS_DIR/<model_id>.
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")

# Known model ids. Every entry needs its trailing comma: adjacent string
# literals are otherwise silently concatenated into one bogus id.
MODEL_IDS = [
    "CompVis/stable-diffusion-v1-4",
    "hakurei/waifu-diffusion",
    # "hakurei/waifu-diffusion-v1-3", - not as diffusers yet
    "runwayml/stable-diffusion-inpainting",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2",
    "stabilityai/stable-diffusion-2-base",
    "stabilityai/stable-diffusion-2-inpainting",
]
def loadModel(
    model_id: str,
    load=True,
    precision=None,
    revision=None,
    send_opts={},
    pipeline_class=None,
):
    """
    Load ``model_id`` onto ``device``, or just download it into the cache.

    Pipeline class selection:
    - if the PIPELINE env var is "ALL" or unset, use ``pipeline_class``
      (default ``AutoPipelineForText2Image``);
    - otherwise use the ``diffusers.pipelines`` class named by PIPELINE.

    The scheduler is DEFAULT_SCHEDULER (via getScheduler()). A local copy
    in MODELS_DIR/<model_id> is preferred over the hub id.

    Args:
        model_id: Model to load or download.
        load: True to load from local files only and move the model to
            ``device``; False to download (network allowed) and discard.
        precision: Passed to torch_dtype_from_precision() (e.g. "fp16").
        revision: Model revision; "" is treated as None.
        send_opts: Not used by this function.
        pipeline_class: Pipeline class to use when PIPELINE is "ALL"/unset.

    Returns:
        The loaded pipeline when ``load`` is True, otherwise None.
    """
    torch_dtype = torch_dtype_from_precision(precision)
    if revision == "":
        revision = None

    print(
        "loadModel",
        {
            "model_id": model_id,
            "load": load,
            "precision": precision,
            "revision": revision,
            "pipeline_class": pipeline_class,
        },
    )

    if not pipeline_class:
        pipeline_class = AutoPipelineForText2Image
    # An unset PIPELINE is treated like "ALL": getattr() with a None name
    # would raise a TypeError.
    if PIPELINE in (None, "", "ALL"):
        pipeline = pipeline_class
    else:
        pipeline = getattr(_pipelines, PIPELINE)
    print("pipeline", pipeline)

    print(
        ("Loading" if load else "Downloading")
        + " model: "
        + model_id
        + (f" ({revision})" if revision else "")
    )

    scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)

    model_dir = os.path.join(MODELS_DIR, model_id)
    if not os.path.isdir(model_dir):
        model_dir = None

    from_pretrained = time.time()
    model = pipeline.from_pretrained(
        model_dir or model_id,
        revision=revision,
        torch_dtype=torch_dtype,
        use_auth_token=HF_AUTH_TOKEN,
        scheduler=scheduler,
        local_files_only=load,
        # Work around https://github.com/huggingface/diffusers/issues/1246
        # low_cpu_mem_usage=False if USE_DREAMBOOTH else True,
    )
    from_pretrained = round((time.time() - from_pretrained) * 1000)

    if load:
        to_gpu = time.time()
        model.to(device)
        to_gpu = round((time.time() - to_gpu) * 1000)
        print(f"Loaded from disk in {from_pretrained} ms, to gpu in {to_gpu} ms")
    else:
        print(f"Downloaded in {from_pretrained} ms")

    return model if load else None
================================================
FILE: api/precision.py
================================================
import os
import torch
# PRECISION is the deprecated name of MODEL_PRECISION; it is still honoured
# as a fallback.
DEPRECATED_PRECISION = os.getenv("PRECISION")
MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION
MODEL_REVISION = os.getenv("MODEL_REVISION")

# Import-time warnings about configuration whose meaning has changed.
if DEPRECATED_PRECISION:
    print("Warning: PRECISION variable has been deprecated and renamed MODEL_PRECISION")
    print("Your setup still works but in a future release, this will throw an error")
if MODEL_PRECISION and not MODEL_REVISION:
    print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please")
    print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you')
    print("want.")
def revision_from_precision(precision=MODEL_PRECISION):
    """
    Removed API: MODEL_REVISION is no longer derived from the precision.

    Always raises ``Exception``.
    """
    # Former behaviour: return precision if precision else None
    message = "revision_from_precision no longer supported"
    raise Exception(message)
def torch_dtype_from_precision(precision=MODEL_PRECISION):
    """
    Map a precision string to the ``torch_dtype`` to load a model with.

    (Previously this function was defined twice in a row; the identical
    second copy has been removed.)

    Args:
        precision: e.g. "fp16"; defaults to MODEL_PRECISION.

    Returns:
        ``torch.float16`` for "fp16", otherwise None (so ``from_pretrained``
        uses its default dtype).
    """
    if precision == "fp16":
        return torch.float16
    return None
================================================
FILE: api/send.py
================================================
import json
import os
import datetime
import time
import requests
import hashlib
from requests_futures.sessions import FuturesSession
from status import status as statusInstance
# Substrings marking environment variables whose values must not be logged.
SECRET_ENV_MARKERS = ("TOKEN", "SECRET", "PASSWORD", "KEY")


def redact_env(env):
    """
    Return a dict copy of ``env`` with secret-looking values replaced by "XXX".

    A variable counts as secret if its (upper-cased) name contains any of
    SECRET_ENV_MARKERS. This covers AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY and HF_AUTH_TOKEN, and also SIGN_KEY, which was
    previously logged in clear text. Empty values are left as-is.
    """
    redacted = dict(env)
    for key, value in redacted.items():
        if value and any(marker in key.upper() for marker in SECRET_ENV_MARKERS):
            redacted[key] = "XXX"
    return redacted


# Dump the (redacted) environment at import time to aid debugging.
print()
environ = redact_env(os.environ)
print(environ)
print()
def get_now():
    """Current wall-clock time as whole milliseconds since the epoch."""
    millis = time.time() * 1000
    return round(millis)
# Optional webhook that receives every send() event; unset or "" -> None.
SEND_URL = os.getenv("SEND_URL") or None
# Optional key used to sign webhook payloads; unset or "" -> None.
SIGN_KEY = os.getenv("SIGN_KEY") or None

# Session used for the webhook POSTs in send(), whose futures are never
# awaited.
futureSession = FuturesSession()
# Identifier included in every event: CONTAINER_ID if set, otherwise the id
# in the first "/containers/<id>/..." path found in /proc/self/mountinfo.
container_id = os.getenv("CONTAINER_ID")
if not container_id:
    try:
        with open("/proc/self/mountinfo") as file:
            for line in file:
                if "/containers/" in line:
                    # Text to the right of "/containers/", up to the next "/".
                    container_id = line.strip().split("/containers/")[-1]
                    container_id = container_id.split("/")[0]
                    break
    except OSError:
        # Best effort: e.g. not running on Linux; leave container_id unset
        # rather than failing at import time.
        pass
# Becomes True once the first non-forced clearSession() call has been consumed.
init_used = False


def clearSession(force=False):
    """
    Start a fresh timing session: a dict whose "_ctime" is the creation
    time in ms.

    The first non-forced call only sets ``init_used`` and keeps the current
    session. Every later call, and any call with ``force=True``, replaces it.
    """
    global session, init_used
    if not (init_used or force):
        init_used = True
        return
    session = {"_ctime": get_now()}
def getTimings():
    """
    Duration in ms of every event type in the current session.

    The value is ``done - start`` when both were recorded, otherwise -1.
    The "_ctime" bookkeeping entry is skipped.
    """

    def duration(entry):
        started = entry.get("start", None)
        finished = entry.get("done", None)
        return finished - started if started and finished else -1

    return {
        event_type: duration(entry)
        for event_type, entry in session.items()
        if event_type != "_ctime"
    }
async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
    """
    Record a progress event in the current session and publish it.

    Status semantics:
    - "start" opens a timing entry for ``type``;
    - "done" closes it and records ``diff``;
    - any other value is an intermediate progress update.

    The event is printed and mirrored into the global status object
    ("start" -> 0.0, "done" -> 1.0). It is POSTed to the send URL when one
    is configured, signed when a sign key is also set. If
    ``opts["response"]`` is given, the event is streamed to it as an
    NDJSON line.

    Args:
        type: Event type, e.g. "loadModel" or "inference".
        status: "start", "done", or another progress marker.
        payload: Extra JSON-serializable data included in the event.
        opts: Optional "SEND_URL", "SIGN_KEY" and "response" overrides.
    """
    now = get_now()
    send_url = opts.get("SEND_URL", SEND_URL)
    sign_key = opts.get("SIGN_KEY", SIGN_KEY)

    if status == "start":
        session.update({type: {"start": now, "last_time": now}})
    # Best effort: an event with no prior "start" opens its entry now instead
    # of raising KeyError in the middle of a request.
    entry = session.setdefault(type, {"start": now, "last_time": now})
    # Time since the previous event of this type. It must be read before
    # last_time is refreshed; otherwise progress events always reported 0.
    time_since_last = now - entry["last_time"]
    if status == "done":
        entry.update({"done": now, "diff": now - entry["start"]})
    elif status != "start":
        entry["last_time"] = now

    data = {
        "type": type,
        "status": status,
        "container_id": container_id,
        "time": now,
        "t": now - session["_ctime"],
        "tsl": time_since_last,
        "payload": payload,
    }

    if status == "start":
        statusInstance.update(type, 0.0)
    elif status == "done":
        statusInstance.update(type, 1.0)

    if send_url and sign_key:
        # Signature: md5 of the compact JSON (without "sig") + the key.
        signed_input = json.dumps(data, separators=(",", ":")) + sign_key
        data["sig"] = hashlib.md5(signed_input.encode("utf-8")).hexdigest()

    print(datetime.datetime.now(), data)

    if send_url:
        # Fire-and-forget: the returned future is intentionally not awaited.
        futureSession.post(send_url, json=data)

    response = opts.get("response")
    if response:
        print("streaming above")
        await response.send(json.dumps(data) + "\n")
# Create the initial timing session at import time, so send() and
# getTimings() always have a `session` to work with.
clearSession(True)
================================================
FILE: api/server.py
================================================
# Do not edit if deploying to Banana Serverless
# This file is boilerplate for the http server, and follows a strict interface.
# Instead, edit the init() and inference() functions in app.py
from sanic import Sanic, response
from sanic_ext import Extend
import subprocess
import app as user_src
import traceback
import os
import json
# We do the model load-to-GPU step on server startup
# so the model object is available globally for reuse
user_src.init()
# Create the http server app
server = Sanic("my_app")
# Allowed CORS origins (consumed by sanic_ext); defaults to any origin.
server.config.CORS_ORIGINS = os.getenv("CORS_ORIGINS") or "*"
server.config.RESPONSE_TIMEOUT = 60 * 60  # 1 hour (training can be long)
# Register the sanic_ext extensions on the app.
Extend(server)
# Healthchecks verify that the environment is correct on Banana Serverless
@server.route("/healthcheck", methods=["GET"])
def healthcheck(request):
# dependency free way to check if GPU is visible
gpu = False
out = subprocess.run("nvidia-smi", shell=True)
if out.returncode == 0: # success state on shell command
gpu = True
return response.json({"state": "healthy", "gpu": gpu})
# Inference POST handler at '/' is called for every http call from Banana
@server.route("/", methods=["POST"])
async def inference(request):
try:
all_inputs = response.json.loads(request.json)
except:
all_inputs = request.json
call_inputs = all_inputs.get("callInputs", None)
stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
streaming_response = None
if stream_events:
streaming_response = await request.respond(content_type="application/x-ndjson")
try:
output = await user_src.inference(all_inputs, streaming_response)
except Exception as err:
print(err)
output = {
"$error": {
"code": "APP_INFERENCE_ERROR",
"name": type(err).__name__,
"message": str(err),
"stack": traceback.format_exc(),
}
}
if stream_events:
await streaming_response.send(json.dumps(output) + "\n")
else:
return response.json(output)
if __name__ == "__main__":
    # Sanic's run() takes the port as an int.
    server.run(host="0.0.0.0", port=8000, workers=1)
================================================
FILE: api/status.py
================================================
class Status:
    """The current long-running step (its type) and how far along it is."""

    def __init__(self):
        self.type, self.progress = "init", 0.0

    def update(self, type, progress):
        """Record step ``type`` at ``progress`` (send() uses 0.0 at start, 1.0 when done)."""
        self.type, self.progress = type, progress

    def get(self):
        """Snapshot of the current step as a plain dict."""
        return dict(type=self.type, progress=self.progress)


# Shared instance, imported by other modules via `from status import status`.
status = Status()
================================================
FILE: api/tests.py
================================================
from test import runTest
def test_memory_free_on_swap_model():
    """
    Make sure memory is freed when swapping models at runtime.

    The same model is run at full precision and then swapped for its fp16
    variant; the second run must report lower memory usage.
    """
    mem_usage = []
    # "" = full precision, then "fp16" = half precision.
    for precision in ("", "fp16"):
        result = runTest(
            "txt2img",
            {},
            {
                "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
                "MODEL_PRECISION": precision,
                "MODEL_URL": "s3://",
            },
            {"num_inference_steps": 1},
        )
        mem_usage.append(result["$mem_usage"])

    print({"mem_usage": mem_usage})
    # Assert that less memory used when unloading fp32 model and
    # loading the fp16 variant in its place
    assert mem_usage[1] < mem_usage[0]
================================================
FILE: api/train_dreambooth.py
================================================
# Based on https://github.com/huggingface/diffusers/commits/main/examples/dreambooth/train_dreambooth.py
# Synced to commit b9feed87958c27074b0618cc543696c05f58e2c9 on 2023-07-12
# Reasons for not using that file directly:
#
# 1) Use our already loaded model from `init()`
# 2) Callback to run after every iteration
# Deps
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
import shutil
import warnings
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, model_info, upload_folder
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# DDA
from send import send as _send
from utils import Storage
import subprocess
import re
import shutil
import asyncio
# Our original code in docker-diffusers-api:
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")


def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
    """
    Synchronous wrapper around the async ``send.send``.

    Each call runs the coroutine to completion with ``asyncio.run``, i.e. in
    a fresh event loop. It therefore cannot be called from a thread that
    already has a running event loop.
    """
    pending = _send(type, status, payload, send_opts)
    asyncio.run(pending)
# Parameter names whose values must never be written to logs.
_SECRET_PARAM_KEYS = ("hub_token",)


def _redact_secrets(mapping):
    """Return a shallow copy of ``mapping`` with non-empty secret values masked as "***"."""
    return {
        key: ("***" if key in _SECRET_PARAM_KEYS and value else value)
        for key, value in mapping.items()
    }


def _compress_and_upload(output_dir, dest_url, send_opts):
    """Pack ``output_dir`` into a tar+zstd archive, upload it to ``dest_url``, then delete the archive.

    The archive name is the last path segment of ``Storage(dest_url).path``
    (or of ``output_dir`` when that path is empty). ``.tar.zstd`` is appended
    when the name contains no ".". Returns ``storage.upload_file``'s result.
    """
    import shlex  # local import: quoting for the shell pipeline below

    storage = Storage(dest_url)
    filename = storage.path if storage.path != "" else output_dir
    filename = filename.split("/").pop()
    print(filename)
    if not re.search(r"\.", filename):
        filename += ".tar.zstd"
    print(filename)

    try:
        # fp16 model timings: zip 1m20s, tar+zstd 4s and a tiny bit smaller!
        send("compress", "start", {}, send_opts)
        # TODO, steaming upload (turns out docker disk write is super slow)
        # Both values are shell-quoted: ``filename`` is derived from the
        # caller-supplied dest_url and is interpolated into a shell=True command.
        subprocess.run(
            f"tar cvf - -C {shlex.quote(output_dir)} . | zstd -o {shlex.quote(filename)}",
            shell=True,
            check=True,  # TODO, rather don't raise and return an error in JSON
        )
        send("compress", "done", {}, send_opts)
        subprocess.run(["ls", "-l", filename])

        send("upload", "start", {}, send_opts)
        upload_result = storage.upload_file(filename, filename)
        send("upload", "done", {}, send_opts)
        print(upload_result)
    finally:
        # Never leave the (potentially large) archive behind, even on failure.
        if os.path.exists(filename):
            os.remove(filename)
    return upload_result


def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opts):
    """Run a DreamBooth fine-tune of ``pipeline`` and optionally upload the result.

    Required ``model_inputs``: ``instance_images`` (objects with a
    ``.save(path)`` method, presumably PIL images) and ``instance_prompt``.
    Every other key in ``model_inputs`` overrides the matching default in
    ``params`` below.

    Args:
        model_id: Used as ``pretrained_model_name_or_path``.
        pipeline: Already-loaded pipeline whose components ``main()`` trains.
        model_inputs: Training parameters. NOTE: ``instance_images`` is
            deleted from this dict (the caller's dict is mutated).
        call_inputs: Request options. When ``dest_url`` is given, the trained
            ``output_dir`` is compressed and uploaded there via ``Storage``.
        send_opts: Forwarded to ``send()`` / ``main()`` for status events.

    Returns:
        dict: ``main()``'s result merged over ``{"no_upload": True}`` when
        neither ``push_to_hub`` nor ``dest_url`` was given.
    """
    # required inputs: instance_images instance_prompt
    params = {
        # Defaults
        "pretrained_model_name_or_path": model_id,  # DDA, TODO
        # Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be
        # float32 precision.
        "revision": None,
        "tokenizer_name": None,
        "instance_data_dir": "instance_data_dir",  # DDA TODO
        "class_data_dir": "class_data_dir",  # DDA, was: None,
        # instance_prompt
        "class_prompt": None,
        "with_prior_preservation": False,
        "prior_loss_weight": 1.0,
        "num_class_images": 100,
        "output_dir": "text-inversion-model",
        "seed": None,
        "resolution": 512,
        # Whether to center crop the input images to the resolution. If not set, the images will be randomly
        # cropped. The images will be resized to the resolution first before cropping.
        "center_crop": False,
        # Whether to train the text encoder. If set, the text encoder should be float32 precision.
        "train_text_encoder": None,
        "train_batch_size": 1,  # DDA, was: 4
        "sample_batch_size": 1,  # DDA, was: 4,
        "num_train_epochs": 1,
        "max_train_steps": 800,  # DDA, was: None,
        # Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`.
        # In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.
        # Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.
        # See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step
        # instructions.
        "checkpointing_steps": 1000000000,  # DDA, was: 500
        # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.
        # See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state
        # for more details
        "checkpoints_total_limit": None,
        "resume_from_checkpoint": None,
        "gradient_accumulation_steps": 1,
        "gradient_checkpointing": True,  # DDA was: None (needed for 16GB)
        "learning_rate": 5e-6,
        "scale_lr": False,
        "lr_scheduler": "constant",
        "lr_warmup_steps": 0,  # DDA, was: 500,
        "lr_num_cycles": 1,
        # Power factor of the polynomial scheduler
        "lr_power": 1.0,
        "use_8bit_adam": True,  # DDA, was: None (needed for 16GB)
        # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
        "dataloader_num_workers": 0,
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_weight_decay": 1e-6,
        "adam_epsilon": 1e-08,
        "max_grad_norm": 1.0,
        "push_to_hub": None,
        "hub_token": HF_AUTH_TOKEN,
        "hub_model_id": None,
        "logging_dir": "logs",
        # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
        # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        "allow_tf32": None,
        # The integration to report the results and logs to. Supported platforms are `"tensorboard"`
        # (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.
        "report_to": "tensorboard",
        # A prompt that is used during validation to verify that the model is learning.
        "validation_prompt": None,
        # Number of images that should be generated during validation with `validation_prompt`
        "num_validation_images": 4,
        # Run validation every X steps. Validation consists of running the prompt
        # `args.validation_prompt` multiple times: `args.num_validation_images`
        # and logging the images.
        "validation_steps": 100,
        "mixed_precision": "fp16",  # DDA, was: None
        # Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=
        # 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32.
        "prior_generation_precision": None,  # "no", "fp32", "fp16", "bf16"
        "local_rank": -1,
        "enable_xformers_memory_efficient_attention": None,
        # Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain
        # behaviors, so disable this argument if it causes any problems. More info:
        # https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
        "set_grads_to_none": None,
        # Fine-tuning against a modified noise"
        # See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information.
        "offset_noise": False,
        # Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.
        "pre_compute_text_embeddings": False,
        # The maximum length of the tokenizer. If not set, will default to the tokenizer's max length."
        "tokenizer_max_length": None,
        # Whether to use attention mask for the text encoder
        "text_encoder_use_attention_mask": False,
        # Set to not save text encoder
        "skip_save_text_encoder": False,
        # Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.
        "validation_images": None,
        # The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.
        "class_labels_conditioning": None,
    }

    instance_images = model_inputs["instance_images"]
    # NOTE: mutates the caller's dict (kept for compatibility with callers).
    del model_inputs["instance_images"]
    params.update(model_inputs)
    # Redact secrets before logging: hub_token defaults to HF_AUTH_TOKEN.
    print(_redact_secrets(model_inputs))
    args = argparse.Namespace(**params)
    print(argparse.Namespace(**_redact_secrets(vars(args))))

    if args.train_text_encoder and args.pre_compute_text_embeddings:
        raise ValueError(
            "`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`"
        )

    result = {}
    dest_url = call_inputs.get("dest_url")
    # Same truthiness test as the upload step below, so "" is reported as no upload.
    if not args.push_to_hub and not dest_url:
        print()
        print("WARNING: Neither modelInputs.push_to_hub nor callInputs.dest_url")
        print("was given. After training, your model won't be uploaded anywhere.")
        print()
        result.update({"no_upload": True})

    # TODO, not save at all... we're just getting it working
    # if its a hassle, in interim, at least save to unique dir
    created_instance_dir = not os.path.exists(args.instance_data_dir)
    os.makedirs(args.instance_data_dir, exist_ok=True)
    saved_image_paths = []
    try:
        for i, image in enumerate(instance_images):
            image_path = args.instance_data_dir + "/image" + str(i) + ".png"
            saved_image_paths.append(image_path)
            image.save(image_path)
        subprocess.run(["ls", "-l", args.instance_data_dir])

        result = result | main(args, pipeline, send_opts=send_opts)

        if dest_url:
            _compress_and_upload(args.output_dir, dest_url, send_opts)
    finally:
        # Cleanup, also when training/upload fails. ignore_errors so a missing
        # dir never masks the original exception.
        shutil.rmtree(args.output_dir, ignore_errors=True)
        shutil.rmtree(args.class_data_dir, ignore_errors=True)
        # DreamBoothDataset trains on every file in instance_data_dir, so this
        # request's images must not survive into the next request.
        if created_instance_dir:
            shutil.rmtree(args.instance_data_dir, ignore_errors=True)
        else:
            for image_path in saved_image_paths:
                if os.path.exists(image_path):
                    os.remove(image_path)
    return result
# What follows is mostly the original train_dreambooth.py
# Any changes are marked in comments with [DDA].

# wandb is optional: only import it when installed. log_validation() uses it
# for the "wandb" tracker branch.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.19.0.dev0")

# Module-level logger from accelerate's get_logger (used with main_process_only below).
logger = get_logger(__name__)
def save_model_card(
    repo_id: str,
    images=None,
    base_model=str,
    train_text_encoder=False,
    prompt=str,
    repo_folder=None,
    pipeline: DiffusionPipeline = None,
):
    """Write sample images and a model card (README.md) into ``repo_folder``.

    Args:
        repo_id: Hub repo id, shown in the card title.
        images: Optional iterable of images (objects with ``.save(path)``).
            Each is saved as ``image_{i}.png`` and linked from the card.
        base_model: Model the weights were derived from.
        train_text_encoder: Recorded in the card text.
        prompt: Instance prompt, written to the YAML front matter and the card.
        repo_folder: Directory receiving the images and README.md.
        pipeline: Selects Stable Diffusion tags when it is a
            ``StableDiffusionPipeline``; otherwise IF tags are used.
    """
    # [DDA] Local imports. StableDiffusionPipeline is not imported at module
    # level in this file, so the isinstance() checks raised NameError.
    import json

    from diffusers import StableDiffusionPipeline

    is_stable_diffusion = isinstance(pipeline, StableDiffusionPipeline)

    img_str = ""
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    # The prompt is free text. JSON string syntax is a valid YAML double-quoted
    # scalar, so quoting it keeps ": " / " #" inside a prompt from breaking the
    # front matter.
    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {json.dumps(str(prompt))}
tags:
- {'stable-diffusion' if is_stable_diffusion else 'if'}
- {'stable-diffusion-diffusers' if is_stable_diffusion else 'if-diffusers'}
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
    model_card = f"""
# DreamBooth - {repo_id}
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
You can find some example images in the following. \n
{img_str}
DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
def log_validation(
    text_encoder,
    tokenizer,
    unet,
    vae,
    args,
    accelerator,
    weight_dtype,
    epoch,
    prompt_embeds,
    negative_prompt_embeds,
):
    """Generate validation images with the models being trained and log them.

    Builds a ``DiffusionPipeline`` from ``args.pretrained_model_name_or_path``
    with the given ``tokenizer`` and the unwrapped ``unet`` / ``text_encoder``
    (and ``vae`` when not None) swapped in. The scheduler is replaced with
    ``DPMSolverMultistepScheduler``. The function then generates either:

    - ``args.num_validation_images`` images from ``args.validation_prompt``,
      or from the precomputed ``prompt_embeds`` / ``negative_prompt_embeds``
      when ``args.pre_compute_text_embeddings``; or
    - one image per path in ``args.validation_images``.

    Images are logged to "tensorboard" and/or "wandb" trackers at step ``epoch``.

    Returns:
        list: the generated images.
    """
    # [DDA] DPMSolverMultistepScheduler is not imported at module level in this
    # file; without this import the scheduler swap below raises NameError.
    from diffusers import DPMSolverMultistepScheduler

    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline_args = {}
    if vae is not None:
        pipeline_args["vae"] = vae
    if text_encoder is not None:
        text_encoder = accelerator.unwrap_model(text_encoder)
    # create pipeline (note: unet and vae are loaded again in float32)
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=accelerator.unwrap_model(unet),
        revision=args.revision,
        torch_dtype=weight_dtype,
        **pipeline_args,
    )
    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}
    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config, **scheduler_args
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)
    if args.pre_compute_text_embeddings:
        pipeline_args = {
            "prompt_embeds": prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
        }
    else:
        pipeline_args = {"prompt": args.validation_prompt}
    # run inference
    generator = (
        None
        if args.seed is None
        else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    )
    images = []
    if args.validation_images is None:
        for _ in range(args.num_validation_images):
            with torch.autocast("cuda"):
                image = pipeline(
                    **pipeline_args, num_inference_steps=25, generator=generator
                ).images[0]
            images.append(image)
    else:
        for image in args.validation_images:
            image = Image.open(image)
            image = pipeline(**pipeline_args, image=image, generator=generator).images[
                0
            ]
            images.append(image)
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(
                "validation", np_images, epoch, dataformats="NHWC"
            )
        if tracker.name == "wandb":
            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                        for i, image in enumerate(images)
                    ]
                }
            )
    del pipeline
    torch.cuda.empty_cache()
    return images
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str
):
    """Return the text-encoder class named by the model's ``text_encoder`` config.

    Reads ``architectures[0]`` from the ``text_encoder`` subfolder config and
    lazily imports the matching class.

    Raises:
        ValueError: if the architecture is not CLIPTextModel,
            RobertaSeriesModelWithTransformation or T5EncoderModel.
    """
    architecture = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    ).architectures[0]

    # Import only the class that is actually needed.
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    elif architecture == "T5EncoderModel":
        from transformers import T5EncoderModel as encoder_cls
    else:
        raise ValueError(f"{architecture} is not supported.")
    return encoder_cls
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with:
      - ``instance_images``: transformed instance image tensor.
      - ``instance_prompt_ids``: ``encoder_hidden_states`` when given, otherwise
        the tokenized ``instance_prompt`` (plus ``instance_attention_mask``).
      - only when ``class_data_root`` is given: ``class_images`` and
        ``class_prompt_ids`` (``instance_prompt_encoder_hidden_states`` when
        given, otherwise the tokenized ``class_prompt`` plus
        ``class_attention_mask``).

    NOTE(review): ``instance_prompt_encoder_hidden_states`` is what feeds
    ``class_prompt_ids``. ``main()`` currently computes it from
    ``args.instance_prompt`` rather than ``args.class_prompt``; confirm that
    is intended.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        instance_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        """
        Args:
            instance_data_root: Directory; every entry in it is loaded as an instance image.
            instance_prompt: Prompt paired with every instance image.
            tokenizer: Used via ``tokenize_prompt`` when no precomputed embeddings are given.
            class_data_root: Optional directory of class (prior-preservation) images; created if missing.
            class_prompt: Prompt paired with every class image.
            class_num: Optional cap on how many class images are used.
            size: Resize target and crop size.
            center_crop: Center crop when True, random crop otherwise.
            encoder_hidden_states: Precomputed embeddings used as ``instance_prompt_ids``.
            instance_prompt_encoder_hidden_states: Precomputed embeddings used as ``class_prompt_ids``.
            tokenizer_max_length: Overrides the tokenizer's ``model_max_length``.

        Raises:
            ValueError: if ``instance_data_root`` does not exist or is empty, or
                if ``class_data_root`` is given but yields no class images.
                Either case would otherwise fail later with a modulo-by-zero in
                ``__getitem__``, or train on nothing.
        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.instance_prompt_encoder_hidden_states = (
            instance_prompt_encoder_hidden_states
        )
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(
                f"Instance {self.instance_data_root} images root doesn't exists."
            )

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            raise ValueError(
                f"Instance {self.instance_data_root} images root is empty."
            )
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            if self.num_class_images == 0:
                raise ValueError(
                    f"Class {self.class_data_root} images root has no usable images "
                    f"(class_num={class_num})."
                )
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.CenterCrop(size)
                if center_crop
                else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        # Indices wrap around so the shorter of the instance/class sets repeats.
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(
                self.tokenizer,
                self.instance_prompt,
                tokenizer_max_length=self.tokenizer_max_length,
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images]
            )
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.instance_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer,
                    self.class_prompt,
                    tokenizer_max_length=self.tokenizer_max_length,
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example
def collate_fn(examples, with_prior_preservation=False):
    """Collate DreamBoothDataset items into one training batch.

    All instance entries come first. With ``with_prior_preservation``, the
    class entries are appended after them, so both halves go through a single
    forward pass.

    Returns:
        dict with ``input_ids`` (concatenated on dim 0), ``pixel_values``
        (stacked, contiguous, float32) and, when the items carry attention
        masks, ``attention_mask`` (concatenated on dim 0).
    """
    with_mask = "instance_attention_mask" in examples[0]
    groups = ("instance", "class") if with_prior_preservation else ("instance",)

    def gather(suffix):
        # Group-major order: every instance entry, then every class entry.
        return [item[f"{group}_{suffix}"] for group in groups for item in examples]

    prompt_ids = gather("prompt_ids")
    images = gather("images")
    masks = gather("attention_mask") if with_mask else None

    stacked = torch.stack(images)
    stacked = stacked.to(memory_format=torch.contiguous_format).float()

    batch = {"input_ids": torch.cat(prompt_ids, dim=0), "pixel_values": stacked}
    if with_mask:
        batch["attention_mask"] = torch.cat(masks, dim=0)
    return batch
class PromptDataset(Dataset):
    """Yield the same ``prompt`` ``num_samples`` times, each tagged with its index.

    Used to prepare the prompts to generate class images on multiple GPUs.
    """

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}
def model_has_vae(args):
    """Return True if the model has a ``vae/<AutoencoderKL.config_name>`` file.

    A local directory is checked on disk. Anything else is treated as a Hub
    repo id and checked against ``model_info(...).siblings`` at
    ``args.revision``.
    """
    relative_config = os.path.join("vae", AutoencoderKL.config_name)
    model_path = args.pretrained_model_name_or_path

    if not os.path.isdir(model_path):
        siblings = model_info(model_path, revision=args.revision).siblings
        return any(sibling.rfilename == relative_config for sibling in siblings)

    return os.path.isfile(os.path.join(model_path, relative_config))
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` to fixed-length PyTorch tensors.

    Input is truncated and padded to ``tokenizer_max_length`` when given,
    otherwise to ``tokenizer.model_max_length``.
    """
    max_length = (
        tokenizer.model_max_length
        if tokenizer_max_length is None
        else tokenizer_max_length
    )
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
def encode_prompt(
    text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None
):
    """Run ``text_encoder`` on ``input_ids`` and return its first output.

    ``input_ids`` is moved to ``text_encoder.device``. ``attention_mask`` is
    moved there and passed along only when ``text_encoder_use_attention_mask``
    is truthy; otherwise ``attention_mask=None`` is passed.
    """
    device = text_encoder.device
    mask = attention_mask.to(device) if text_encoder_use_attention_mask else None
    outputs = text_encoder(input_ids.to(device), attention_mask=mask)
    return outputs[0]
def main(args, init_pipeline, send_opts):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
project_dir=args.output_dir, logging_dir=logging_dir
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError(
"Make sure to install wandb if you want to use it for logging during training."
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if (
args.train_text_encoder
and args.gradient_accumulation_steps > 1
and accelerator.num_processes > 1
):
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
# DDA
# torch_dtype = (
# torch.float16 if accelerator.device.type == "cuda" else torch.float32
# )
# if args.prior_generation_precision == "fp32":
# torch_dtype = torch.float32
# elif args.prior_generation_precision == "fp16":
# torch_dtype = torch.float16
# elif args.prior_generation_precision == "bf16":
# torch_dtype = torch.bfloat16
# DDA
pipeline = init_pipeline
pipeline.safety_checker = None
# pipeline = DiffusionPipeline.from_pretrained(
# args.pretrained_model_name_or_path,
# torch_dtype=torch_dtype,
# safety_checker=None,
# revision=args.revision,
# )
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(
sample_dataset, batch_size=args.sample_batch_size
)
sample_dataloader = accelerator.prepare(sample_dataloader)
# pipeline.to(accelerator.device) # DDA already done
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not accelerator.is_local_main_process,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
)
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
).repo_id
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name, revision=args.revision, use_fast=False
)
elif args.pretrained_model_name_or_path:
tokenizer = init_pipeline.components["tokenizer"] # DDA
# tokenizer = AutoTokenizer.from_pretrained(
# args.pretrained_model_name_or_path,
# subfolder="tokenizer",
# revision=args.revision,
# use_auth_token=args.hub_token, # DDA
# local_files_only=True, # DDA
# )
# import correct text encoder class
# DDA
# text_encoder_cls = import_model_class_from_model_name_or_path(
# args.pretrained_model_name_or_path,
# args.revision
# )
# Load scheduler and models
# noise_scheduler = DDPMScheduler.from_pretrained(
# args.pretrained_model_name_or_path,
# subfolder="scheduler",
# use_auth_token=args.hub_token, # DDA
# local_files_only=True, # DDA
# )
# text_encoder = text_encoder_cls.from_pretrained(
# args.pretrained_model_name_or_path,
# subfolder="text_encoder",
# revision=args.revision,
# use_auth_token=args.hub_token, # DDA
# local_files_only=True, # DDA
# )
# if model_has_vae(args):
# vae = AutoencoderKL.from_pretrained(
# args.pretrained_model_name_or_path,
# subfolder="vae",
# revision=args.revision
# use_auth_token=args.hub_token, # DDA
# local_files_only=True, # DDA
# )
# else:
# vae = None
# unet = UNet2DConditionModel.from_pretrained(
# args.pretrained_model_name_or_path,
# subfolder="unet",
# revision=args.revision,
# use_auth_token=args.hub_token, # DDA
# local_files_only=True, # DDA
# )
# print("pipeline.disable_xformers_memory_efficient_attention()")
# init_pipeline.disable_xformers_memory_efficient_attention()
noise_scheduler = init_pipeline.components["scheduler"] # DDA
text_encoder = init_pipeline.components["text_encoder"] # DDA
vae = init_pipeline.components["vae"] # DDA
unet = init_pipeline.components["unet"] # DDA
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
for model in models:
sub_dir = (
"unet"
if isinstance(model, type(accelerator.unwrap_model(unet)))
else "text_encoder"
)
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(
input_dir, subfolder="text_encoder"
)
model.config = load_model.config
else:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(
input_dir, subfolder="unet"
)
model.register_to_config(**load_model.config)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# TODO, how does this affect things outside of train_dreambooth?
if vae is not None:
vae.requires_grad_(False)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError(
"xformers is not available. Make sure it is installed correctly"
)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# Check that all trainable models are in full precision
low_precision_error_string = (
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training. copy of the weights should still be float32."
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
if (
args.train_text_encoder
and accelerator.unwrap_model(text_encoder).dtype != torch.float32
):
raise ValueError(
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate
* args.gradient_accumulation_steps
* args.train_batch_size
* accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.pre_compute_text_embeddings:
def compute_text_embeddings(prompt):
with torch.no_grad():
text_inputs = tokenize_prompt(
tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length
)
prompt_embeds = encode_prompt(
text_encoder,
text_inputs.input_ids,
text_inputs.attention_mask,
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
return prompt_embeds
pre_computed_encoder_hidden_states = compute_text_embeddings(
args.instance_prompt
)
validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
if args.validation_prompt is not None:
validation_prompt_encoder_hidden_states = compute_text_embeddings(
args.validation_prompt
)
else:
validation_prompt_encoder_hidden_states = None
if args.instance_prompt is not None:
pre_computed_instance_prompt_encoder_hidden_states = (
compute_text_embeddings(args.instance_prompt)
)
else:
pre_computed_instance_prompt_encoder_hidden_states = None
text_encoder = None
tokenizer = None
gc.collect()
torch.cuda.empty_cache()
else:
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None
pre_computed_instance_prompt_encoder_hidden_states = None
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
class_num=args.num_class_images,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
unet,
text_encoder,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae and text_encoder to device and cast to weight_dtype
if vae is not None:
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder and text_encoder is not None:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth", config=vars(args))
# Train!
total_batch_size = (
args.train_batch_size
* accelerator.num_processes
* args.gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (
num_update_steps_per_epoch * args.gradient_accumulation_steps
)
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(global_step, args.max_train_steps),
disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")
# DDA
send("training", "start", {}, send_opts)
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if (
args.resume_from_checkpoint
and epoch == first_epoch
and step < resume_step
):
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
if vae is not None:
# Convert images to latent space
model_input = vae.encode(
batch["pixel_values"].to(dtype=weight_dtype)
).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values
# Sample noise that we'll add to the model input
if args.offset_noise:
noise = torch.randn_like(model_input) + 0.1 * torch.randn(
model_input.shape[0],
model_input.shape[1],
1,
1,
device=model_input.device,
)
else:
noise = torch.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(bsz,),
device=model_input.device,
)
timesteps = timesteps.long()
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(
model_input, noise, timesteps
)
# Get the text embedding for conditioning
if args.pre_compute_text_embeddings:
encoder_hidden_states = batch["input_ids"]
else:
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
batch["attention_mask"],
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)
if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat(
[noisy_model_input, noisy_model_input], dim=1
)
if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None
# Predict the noise residual
model_pred = unet(
noisy_model_input,
timesteps,
encoder_hidden_states,
class_labels=class_labels,
).sample
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
)
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="mean"
)
# Compute prior loss
prior_loss = F.mse_loss(
model_pred_prior.float(), target_prior.float(), reduction="mean"
)
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="mean"
)
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [
d for d in checkpoints if d.startswith("checkpoint")
]
checkpoints = sorted(
checkpoints, key=lambda x: int(x.split("-")[1])
)
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = (
len(checkpoints) - args.checkpoints_total_limit + 1
)
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(
f"removing checkpoints: {', '.join(removing_checkpoints)}"
)
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(
args.output_dir, removing_checkpoint
)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(
args.output_dir, f"checkpoint-{global_step}"
)
pipeline.save_pretrained(save_path)
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
images = []
if (
args.validation_prompt is not None
and global_step % args.validation_steps == 0
):
images = log_validation(
text_encoder,
tokenizer,
unet,
vae,
args,
accelerator,
weight_dtype,
epoch,
validation_prompt_encoder_hidden_states,
validation_prompt_negative_prompt_embeds,
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
send("training", "done", {}, send_opts) # DDA
if accelerator.is_main_process:
pipeline_args = {}
if text_encoder is not None:
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
if args.skip_save_text_encoder:
pipeline_args["text_encoder"] = None
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
**pipeline_args,
local_files_only=True, # DDA
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = pipeline.scheduler.from_config(
pipeline.scheduler.config, **scheduler_args
)
pipeline.save_pretrained(args.output_dir, safe_serialization=True)
if args.push_to_hub:
# DDA
send("upload", "start", {}, send_opts)
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
pipeline=pipeline,
)
# repo.push_to_hub(
# commit_message="End of training",
# # DDA need to think about this, quite nice to not block, then could
# # upload while training next request. But, timeout will kill an unused
# # process... what else?
# blocking=True, # DDA, was: False,
# auto_lfs_prune=True,
# )
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
# DDA
# https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L3379
# Whether or not to run this method in the background. Background jobs are run sequentially without
# blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects)
# object. Defaults to `False`.
# run_as_future: TODO
)
# DDA
send("upload", "done", {}, send_opts)
accelerator.end_training()
# DDA
return {"done": True}
================================================
FILE: api/utils/__init__.py
================================================
from .storage import Storage
================================================
FILE: api/utils/storage/BaseStorage.py
================================================
import os
import re
import subprocess
from abc import ABC, abstractmethod
import xtarfile as tarfile
class BaseArchive(ABC):
    """
    Base class for a local archive file that can be unpacked.

    Concrete subclasses provide a ``test(path)`` staticmethod (used by
    ``Archive()`` to pick a handler for a path) and implement ``extract``.
    """

    def __init__(self, path, status=None):
        """
        Args:
            path: Filesystem path of the archive file.
            status: Optional progress sink exposing ``update(type, progress)``.
        """
        self.path = path
        self.status = status

    def updateStatus(self, type, progress):
        """Forward a progress update to ``self.status`` when one was supplied."""
        if self.status:
            self.status.update(type, progress)

    @abstractmethod
    def extract(self, dir=None, dry_run=False):
        """
        Unpack the archive.

        Args:
            dir: Destination directory (subclass decides the default).
            dry_run: If true, compute the result without touching the disk.

        Returns:
            The destination directory.
        """
        # Previously a `print("TODO")` stub that silently returned None; an
        # unimplemented handler is now an explicit error instead.
        raise NotImplementedError

    def splitext(self):
        """
        Strip the last two extensions from ``self.path``.

        Returns:
            ``(base, ext, subext)``; e.g. ``"a/file.tar.zst"`` gives
            ``("a/file", ".zst", ".tar")``. ``base`` keeps any directory part.
        """
        base, ext = os.path.splitext(self.path)
        base, subext = os.path.splitext(base)
        return base, ext, subext
class TarArchive(BaseArchive):
    """Archive handler for tarballs (``.tar``, ``.tar.gz``, ``.tar.zst``, ...)."""

    @staticmethod
    def test(path):
        """Return a truthy regex match when ``path`` contains ``.tar``."""
        return re.search(r"\.tar", path)

    def extract(self, dir=None, dry_run=False):
        """
        Extract the tarball, then delete the archive file.

        Args:
            dir: Destination directory. When falsy, the archive path with its
                last two extensions stripped is used (``models/x.tar.zst`` ->
                ``models/x``).
            dry_run: When true, only compute and return the destination; no
                directory is created, nothing is extracted or deleted.

        Returns:
            The destination directory.

        Raises:
            ValueError: if a member's name would place it outside ``dir``.
        """
        self.updateStatus("extract", 0)

        if not dir:
            # splitext() keeps the parent directory inside `base`, so it is
            # already the archive's sibling directory. Joining it onto
            # dirname(path) again doubled relative paths ("a/x" -> "a/a/x").
            dir, _ext, _subext = self.splitext()

        if dry_run:
            return dir

        # exist_ok: re-extracting into an existing directory must not crash.
        os.makedirs(dir, exist_ok=True)
        root = os.path.realpath(dir)

        def track_progress(tar):
            """Yield members one by one, reporting progress and vetting paths."""
            members = tar.getmembers()
            total = len(members)
            for i, member in enumerate(members, start=1):
                # Archives are fetched from remote URLs: refuse members whose
                # names resolve outside the destination ("tar slip").
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise ValueError(
                        f"Refusing to extract {member.name!r} outside {dir}"
                    )
                self.updateStatus("extract", i / total)
                yield member

        print("Extracting to " + dir)
        with tarfile.open(self.path, "r") as tar:
            tar.extractall(path=dir, members=track_progress(tar))
        subprocess.run(["ls", "-l", dir])
        os.remove(self.path)
        self.updateStatus("extract", 1)
        return dir
# Registry of archive handlers, tried in order by Archive().
archiveClasses = [TarArchive]


def Archive(path, **kwargs):
    """
    Factory: wrap ``path`` in the first archive handler whose ``test`` accepts it.

    Extra keyword arguments (e.g. ``status``) are passed to the handler.
    Returns None when no handler recognises the path.
    """
    handler = next((cls for cls in archiveClasses if cls.test(path)), None)
    if handler is None:
        return None
    return handler(path, **kwargs)
class BaseStorage(ABC):
    """
    Abstract base for URL-addressed storage backends.

    Subclasses implement ``test(url)`` (does this backend handle ``url``?) and
    ``download_file(dest)``.
    """

    @staticmethod
    @abstractmethod
    def test(url):
        """Return a truthy value when this backend can handle ``url``."""
        return re.search(r"^https?://", url)

    def __init__(self, url, **kwargs):
        """
        Args:
            url: Location of the remote object.
            status: (kwarg) optional progress sink exposing
                ``update(type, progress)``.
        """
        self.url = url
        self.status = kwargs.get("status", None)
        self.query = {}

    def updateStatus(self, type, progress):
        """Forward a progress update to ``self.status`` when one was supplied."""
        if self.status:
            self.status.update(type, progress)

    def splitext(self):
        """
        Strip the last two extensions from the URL.

        Returns:
            ``(base, ext, subext)``; e.g. ``".../file.tar.zst"`` gives
            ``(".../file", ".zst", ".tar")``.
        """
        base, ext = os.path.splitext(self.url)
        base, subext = os.path.splitext(base)
        return base, ext, subext

    def get_filename(self):
        """Return the last ``/``-separated component of the URL."""
        return self.url.split("/").pop()

    @abstractmethod
    def download_file(self, dest):
        """Download the file to `dest`"""
        pass

    def download_and_extract(self, fname=None, dir=None, dry_run=False):
        """
        Downloads the file, and if it's an archive, extract it too.

        Args:
            fname: Local filename to download to; defaults to the last
                component of the URL.
            dir: Extraction directory for archives (handler default if None).
            dry_run: When true, perform no download or extraction and just
                return the name that would be produced.

        Returns:
            The filename if not an archive, otherwise the directory it was
            (or would be) extracted to.
        """
        if not fname:
            fname = self.get_filename()
        archive = Archive(fname, status=self.status)

        if dry_run:
            # Previously ignored: a "dry run" really downloaded and extracted.
            if archive:
                return dir or archive.splitext()[0]
            return fname

        # TODO, streaming pipeline
        self.download_file(fname)
        if archive:
            return archive.extract(dir)
        return fname
================================================
FILE: api/utils/storage/BaseStorage_test.py
================================================
import unittest
from . import Storage, S3Storage, HTTPStorage
class BaseStorageTest(unittest.TestCase):
    """Unit tests for helpers shared by every storage backend."""

    def test_get_filename(self):
        """get_filename() yields the final path component of the URL."""
        url = "http://host.com/dir/file.tar.zst"
        expected = "file.tar.zst"
        self.assertEqual(Storage(url).get_filename(), expected)
class Download_and_extract(unittest.TestCase):
    """Tests for BaseStorage.download_and_extract in dry-run mode."""

    def test_file_only(self):
        storage = Storage("http://host.com/dir/file.bin")
        result = storage.download_and_extract(dry_run=True)
        self.assertEqual(result, "file.bin")

    def test_file_archive(self):
        # download_and_extract() returns only the extraction directory; the
        # old (result, base, ext, subext) tuple is no longer returned.
        storage = Storage("http://host.com/dir/file.tar.zst")
        result = storage.download_and_extract(dry_run=True)
        self.assertEqual(result, "file")

    def test_splitext(self):
        # os.path.splitext keeps the leading dot and strips the *last*
        # extension first, so ext is ".zst" and subext is ".tar".
        storage = Storage("http://host.com/dir/file.tar.zst")
        base, ext, subext = storage.splitext()
        self.assertEqual(base, "http://host.com/dir/file")
        self.assertEqual(ext, ".zst")
        self.assertEqual(subext, ".tar")
================================================
FILE: api/utils/storage/HTTPStorage.py
================================================
import re
import os
import time
import requests
from tqdm
gitextract_akizhtm1/ ├── .circleci/ │ └── config.yml ├── .devcontainer/ │ ├── devcontainer.json │ └── local.example.env ├── .gitignore ├── .vscode/ │ ├── settings.json │ └── tasks.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── api/ │ ├── app.py │ ├── convert_to_diffusers.py │ ├── device.py │ ├── download.py │ ├── download_checkpoint.py │ ├── extras/ │ │ ├── __init__.py │ │ └── upsample/ │ │ ├── __init__.py │ │ ├── models.py │ │ └── upsample.py │ ├── getPipeline.py │ ├── getScheduler.py │ ├── lib/ │ │ ├── __init__.py │ │ ├── prompts.py │ │ ├── textual_inversions.py │ │ ├── textual_inversions_test.py │ │ └── vars.py │ ├── loadModel.py │ ├── precision.py │ ├── send.py │ ├── server.py │ ├── status.py │ ├── tests.py │ ├── train_dreambooth.py │ └── utils/ │ ├── __init__.py │ └── storage/ │ ├── BaseStorage.py │ ├── BaseStorage_test.py │ ├── HTTPStorage.py │ ├── S3Storage.py │ ├── S3Storage_test.py │ ├── __init__.py │ └── __init__test.py ├── build ├── docs/ │ ├── internal_safetensor_cache_flow.md │ └── storage.md ├── install.sh ├── package.json ├── prime.sh ├── release.config.js ├── requirements.txt ├── run.sh ├── run_integration_tests_on_lambda.sh ├── scripts/ │ ├── devContainerPostCreate.sh │ ├── devContainerServer.sh │ ├── patchmatch-setup.sh │ ├── permutations.yaml │ └── permute.sh ├── test.py ├── tests/ │ ├── __init__.py │ └── integration/ │ ├── __init__.py │ ├── conftest.py │ ├── lib.py │ ├── requirements.txt │ ├── test_attn_procs.py │ ├── test_build_download.py │ ├── test_cloud_cache.py │ ├── test_dreambooth.py │ ├── test_general.py │ ├── test_loras.py │ └── test_memory.py ├── touch └── update.sh
SYMBOL INDEX (157 symbols across 34 files)
FILE: api/app.py
function tinyVae (line 73) | def tinyVae(origVae: AutoencoderKL):
function init (line 93) | def init():
function decodeBase64Image (line 134) | def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
function getFromUrl (line 140) | def getFromUrl(url: str, name: str) -> PIL.Image:
function truncateInputs (line 147) | def truncateInputs(inputs: dict):
function inference (line 169) | async def inference(all_inputs: dict, response) -> dict:
FILE: api/convert_to_diffusers.py
function main (line 23) | def main(
FILE: api/download.py
function send (line 27) | async def send(type: str, status: str, payload: dict = {}, send_opts: di...
function normalize_model_id (line 34) | def normalize_model_id(model_id: str, model_revision):
function download_model (line 41) | async def download_model(
FILE: api/download_checkpoint.py
function main (line 8) | def main(checkpoint_url: str):
FILE: api/extras/upsample/upsample.py
function cache_path (line 35) | def cache_path(filename):
function assert_model_exists (line 39) | async def assert_model_exists(src, filename, send_opts, opts={}):
function download_models (line 49) | async def download_models(send_opts={}):
function upsample (line 86) | async def upsample(model_inputs, call_inputs, send_opts={}, startRequest...
FILE: api/getPipeline.py
function listAvailablePipelines (line 15) | def listAvailablePipelines():
function availableCommunityPipelines (line 27) | def availableCommunityPipelines():
function clearPipelines (line 40) | def clearPipelines():
function getPipelineClass (line 51) | def getPipelineClass(pipeline_name: str):
function getPipelineForModel (line 58) | def getPipelineForModel(
FILE: api/getScheduler.py
function initScheduler (line 39) | def initScheduler(MODEL_ID: str, scheduler_id: str, download=False):
function getScheduler (line 65) | def getScheduler(MODEL_ID: str, scheduler_id: str, download=False):
FILE: api/lib/prompts.py
function prepare_prompts (line 4) | def prepare_prompts(pipeline, model_inputs, is_sdxl):
FILE: api/lib/textual_inversions.py
function strMap (line 17) | def strMap(str: str):
function extract_tokens_from_list (line 24) | def extract_tokens_from_list(textual_inversions: list):
function handle_textual_inversions (line 28) | async def handle_textual_inversions(textual_inversions: list, model, sta...
FILE: api/lib/textual_inversions_test.py
class TextualInversionsTest (line 5) | class TextualInversionsTest(unittest.TestCase):
method test_extract_tokens_query_fname (line 6) | def test_extract_tokens_query_fname(self):
method test_extract_tokens_query_token (line 11) | def test_extract_tokens_query_token(self):
FILE: api/loadModel.py
function loadModel (line 28) | def loadModel(
FILE: api/precision.py
function revision_from_precision (line 18) | def revision_from_precision(precision=MODEL_PRECISION):
function torch_dtype_from_precision (line 23) | def torch_dtype_from_precision(precision=MODEL_PRECISION):
function torch_dtype_from_precision (line 29) | def torch_dtype_from_precision(precision=MODEL_PRECISION):
FILE: api/send.py
function get_now (line 19) | def get_now():
function clearSession (line 50) | def clearSession(force=False):
function getTimings (line 60) | def getTimings():
function send (line 74) | async def send(type: str, status: str, payload: dict = {}, opts: dict = ...
FILE: api/server.py
function healthcheck (line 27) | def healthcheck(request):
function inference (line 39) | async def inference(request):
FILE: api/status.py
class Status (line 1) | class Status:
method __init__ (line 2) | def __init__(self):
method update (line 6) | def update(self, type, progress):
method get (line 10) | def get(self):
FILE: api/tests.py
function test_memory_free_on_swap_model (line 4) | def test_memory_free_on_swap_model():
FILE: api/train_dreambooth.py
function send (line 78) | def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
function TrainDreamBooth (line 82) | def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, ...
function save_model_card (line 266) | def save_model_card(
function log_validation (line 304) | def log_validation(
function import_model_class_from_model_name_or_path (line 408) | def import_model_class_from_model_name_or_path(
class DreamBoothDataset (line 436) | class DreamBoothDataset(Dataset):
method __init__ (line 442) | def __init__(
method __len__ (line 502) | def __len__(self):
method __getitem__ (line 505) | def __getitem__(self, index):
function collate_fn (line 551) | def collate_fn(examples, with_prior_preservation=False):
class PromptDataset (line 586) | class PromptDataset(Dataset):
method __init__ (line 589) | def __init__(self, prompt, num_samples):
method __len__ (line 593) | def __len__(self):
method __getitem__ (line 596) | def __getitem__(self, index):
function model_has_vae (line 603) | def model_has_vae(args):
function tokenize_prompt (line 617) | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
function encode_prompt (line 634) | def encode_prompt(
function main (line 653) | def main(args, init_pipeline, send_opts):
FILE: api/utils/storage/BaseStorage.py
class BaseArchive (line 8) | class BaseArchive(ABC):
method __init__ (line 9) | def __init__(self, path, status=None):
method updateStatus (line 13) | def updateStatus(self, type, progress):
method extract (line 17) | def extract(self):
method splitext (line 20) | def splitext(self):
class TarArchive (line 26) | class TarArchive(BaseArchive):
method test (line 28) | def test(path):
method extract (line 31) | def extract(self, dir, dry_run=False):
function Archive (line 63) | def Archive(path, **kwargs):
class BaseStorage (line 69) | class BaseStorage(ABC):
method test (line 72) | def test(url):
method __init__ (line 75) | def __init__(self, url, **kwargs):
method updateStatus (line 80) | def updateStatus(self, type, progress):
method splitext (line 84) | def splitext(self):
method get_filename (line 89) | def get_filename(self):
method download_file (line 93) | def download_file(self, dest):
method download_and_extract (line 97) | def download_and_extract(self, fname, dir=None, dry_run=False):
FILE: api/utils/storage/BaseStorage_test.py
class BaseStorageTest (line 5) | class BaseStorageTest(unittest.TestCase):
method test_get_filename (line 6) | def test_get_filename(self):
class Download_and_extract (line 10) | class Download_and_extract(unittest.TestCase):
method test_file_only (line 11) | def test_file_only(self):
method test_file_archive (line 16) | def test_file_archive(self):
FILE: api/utils/storage/HTTPStorage.py
function get_now (line 10) | def get_now():
class HTTPStorage (line 14) | class HTTPStorage(BaseStorage):
method test (line 16) | def test(url):
method __init__ (line 19) | def __init__(self, url, **kwargs):
method upload_file (line 26) | def upload_file(self, source, dest):
method download_file (line 29) | def download_file(self, fname):
FILE: api/utils/storage/S3Storage.py
function get_now (line 18) | def get_now():
class S3Storage (line 22) | class S3Storage(BaseStorage):
method test (line 23) | def test(url):
method __init__ (line 26) | def __init__(self, url, **kwargs):
method s3resource (line 57) | def s3resource(self):
method s3client (line 68) | def s3client(self):
method bucket (line 79) | def bucket(self):
method upload_file (line 86) | def upload_file(self, source, dest):
method download_file (line 109) | def download_file(self, dest):
method file_exists (line 129) | def file_exists(self):
FILE: api/utils/storage/S3Storage_test.py
class S3StorageTest (line 6) | class S3StorageTest(unittest.TestCase):
method test_endpoint_only_s3 (line 7) | def test_endpoint_only_s3(self):
method test_endpoint_only_http_s3 (line 13) | def test_endpoint_only_http_s3(self):
method test_endpoint_only_https_s3 (line 19) | def test_endpoint_only_https_s3(self):
method test_bucket_only (line 25) | def test_bucket_only(self):
method test_url_with_bucket_and_file_only (line 31) | def test_url_with_bucket_and_file_only(self):
method test_full_url_with_subdirectory (line 37) | def test_full_url_with_subdirectory(self):
FILE: api/utils/storage/__init__.py
function Storage (line 9) | def Storage(url, no_raise=False, **kwargs):
FILE: api/utils/storage/__init__test.py
class StorageTest (line 5) | class StorageTest(unittest.TestCase):
method test_url_s3 (line 6) | def test_url_s3(self):
method test_url_http (line 10) | def test_url_http(self):
method test_no_match_raise (line 14) | def test_no_match_raise(self):
method test_no_match_no_raise (line 18) | def test_no_match_no_raise(self):
FILE: test.py
function b64encode_file (line 28) | def b64encode_file(filename: str):
function output_path (line 38) | def output_path(filename: str):
function sizeof_fmt (line 43) | def sizeof_fmt(num, suffix="B"):
function decode_and_save (line 51) | def decode_and_save(image_byte_string: str, name: str):
function test (line 68) | def test(name, inputs):
function runTest (line 73) | def runTest(name, args, extraCallInputs, extraModelInputs):
function main (line 403) | def main(tests_to_run, args, extraCallInputs, extraModelInputs):
FILE: tests/integration/conftest.py
function my_fixture (line 7) | def my_fixture():
FILE: tests/integration/lib.py
function log_streamer (line 40) | def log_streamer(container, name=None):
function get_free_port (line 109) | def get_free_port():
function startContainer (line 117) | def startContainer(image, command=None, stream_logs=False, onstop=None, ...
function getMinio (line 161) | def getMinio(id="disposable"):
function getDDA (line 256) | def getDDA(
function cleanup (line 357) | def cleanup():
FILE: tests/integration/test_attn_procs.py
class TestAttnProcs (line 8) | class TestAttnProcs:
method setup_class (line 9) | def setup_class(self):
method teardown_class (line 21) | def teardown_class(self):
method test_lora_hf_download (line 26) | def test_lora_hf_download(self):
method test_lora_http_download_pytorch_bin (line 50) | def test_lora_http_download_pytorch_bin(self):
method test_lora_http_download_civitai_safetensors (line 75) | def test_lora_http_download_civitai_safetensors(self):
FILE: tests/integration/test_build_download.py
function test_cloudcache_build_download (line 6) | def test_cloudcache_build_download():
function test_huggingface_build_download (line 51) | def test_huggingface_build_download():
function test_checkpoint_url_build_download (line 97) | def test_checkpoint_url_build_download():
FILE: tests/integration/test_cloud_cache.py
function test_cloud_cache_create_and_upload (line 6) | def test_cloud_cache_create_and_upload():
FILE: tests/integration/test_dreambooth.py
class TestDreamBoothS3 (line 8) | class TestDreamBoothS3:
method setup_class (line 13) | def setup_class(self):
method teardown_class (line 17) | def teardown_class(self):
method test_training_s3 (line 21) | def test_training_s3(self):
method test_s3_download_and_inference (line 48) | def test_s3_download_and_inference(self):
class TestDreamBoothHF (line 73) | class TestDreamBoothHF:
method test_training_hf (line 74) | def test_training_hf(self):
method test_hf_download_and_inference (line 103) | def test_hf_download_and_inference(self):
FILE: tests/integration/test_general.py
class TestGeneralClass (line 7) | class TestGeneralClass:
method setup_class (line 22) | def setup_class(self):
method teardown_class (line 34) | def teardown_class(self):
method test_txt2img (line 39) | def test_txt2img(self):
method test_img2img (line 43) | def test_img2img(self):
FILE: tests/integration/test_loras.py
class TestLoRAs (line 7) | class TestLoRAs:
method setup_class (line 8) | def setup_class(self):
method teardown_class (line 20) | def teardown_class(self):
method test_lora_hf_download (line 27) | def test_lora_hf_download(self):
method test_lora_http_download_pytorch_bin (line 53) | def test_lora_http_download_pytorch_bin(self):
method test_lora_http_download_civitai_safetensors (line 77) | def test_lora_http_download_civitai_safetensors(self):
FILE: tests/integration/test_memory.py
function test_memory (line 7) | def test_memory():
Condensed preview — 73 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (283K chars).
[
{
"path": ".circleci/config.yml",
"chars": 3213,
"preview": "version: 2.1\n\njobs:\n build:\n docker:\n - image: cimg/python:3.9-node\n resource_class: medium\n\n # would have"
},
{
"path": ".devcontainer/devcontainer.json",
"chars": 1489,
"preview": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the\n// README at: https://github.co"
},
{
"path": ".devcontainer/local.example.env",
"chars": 534,
"preview": "# Useful environment variables:\n\n# AWS or S3-compatible storage credentials and buckets\nAWS_ACCESS_KEY_ID=\nAWS_SECRET_AC"
},
{
"path": ".gitignore",
"chars": 1863,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".vscode/settings.json",
"chars": 737,
"preview": "{\n \"python.testing.pytestArgs\": [\n \"--cov=.\",\n \"--cov-report=xml\",\n \"--ignore=test.py\",\n \"--ignore=tests/in"
},
{
"path": ".vscode/tasks.json",
"chars": 287,
"preview": "{\n // See https://go.microsoft.com/fwlink/?LinkId=733558\n // for the documentation about the tasks.json format\n \"vers"
},
{
"path": "CHANGELOG.md",
"chars": 31652,
"preview": "# [1.7.0](https://github.com/kiri-art/docker-diffusers-api/compare/v1.6.0...v1.7.0) (2023-09-04)\n\n\n### Bug Fixes\n\n* **ad"
},
{
"path": "CONTRIBUTING.md",
"chars": 6389,
"preview": "# CONTRIBUTING\n\n*Tips for development*\n\n1. [General Hints](#general)\n1. [Development / Editor Setup](#editors)\n 1. [V"
},
{
"path": "Dockerfile",
"chars": 3664,
"preview": "ARG FROM_IMAGE=\"pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime\"\n# ARG FROM_IMAGE=\"gadicc/diffusers-api-base:python3.9-pyt"
},
{
"path": "LICENSE",
"chars": 1075,
"preview": "MIT License\n\nCopyright (c) 2022 Banana, Gadi Cohen\n\nPermission is hereby granted, free of charge, to any person obtainin"
},
{
"path": "README.md",
"chars": 11816,
"preview": "# docker-diffusers-api (\"banana-sd-base\")\n\nDiffusers / Stable Diffusion in docker with a REST API, supporting various mo"
},
{
"path": "__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "api/app.py",
"chars": 26511,
"preview": "import asyncio\nfrom sched import scheduler\nimport torch\n\nfrom torch import autocast\nfrom diffusers import __version__\nim"
},
{
"path": "api/convert_to_diffusers.py",
"chars": 4997,
"preview": "import os\nimport requests\nimport subprocess\nimport torch\nimport json\nfrom diffusers.pipelines.stable_diffusion.convert_f"
},
{
"path": "api/device.py",
"chars": 1004,
"preview": "import torch\n\nif torch.cuda.is_available():\n print(\"[device] CUDA (Nvidia) detected\")\n device_id = \"cuda\"\n devi"
},
{
"path": "api/download.py",
"chars": 7101,
"preview": "# In this file, we define download_model\n# It runs during container build time to get model weights built into the conta"
},
{
"path": "api/download_checkpoint.py",
"chars": 680,
"preview": "import os\nfrom utils import Storage\n\nCHECKPOINT_URL = os.environ.get(\"CHECKPOINT_URL\", None)\nCHECKPOINT_DIR = \"/root/.ca"
},
{
"path": "api/extras/__init__.py",
"chars": 31,
"preview": "from .upsample import upsample\n"
},
{
"path": "api/extras/upsample/__init__.py",
"chars": 31,
"preview": "from .upsample import upsample\n"
},
{
"path": "api/extras/upsample/models.py",
"chars": 2602,
"preview": "upsamplers = {\n \"RealESRGAN_x4plus\": {\n \"name\": \"General - RealESRGANplus\",\n \"weights\": \"https://github"
},
{
"path": "api/extras/upsample/upsample.py",
"chars": 6996,
"preview": "import os\nimport asyncio\nfrom pathlib import Path\n\nimport base64\nfrom io import BytesIO\nimport PIL\nimport json\nimport cv"
},
{
"path": "api/getPipeline.py",
"chars": 3537,
"preview": "import time\nimport os, fnmatch\nfrom diffusers import (\n DiffusionPipeline,\n pipelines as diffusers_pipelines,\n)\nfr"
},
{
"path": "api/getScheduler.py",
"chars": 2743,
"preview": "import torch\nimport os\nimport time\nfrom diffusers import schedulers as _schedulers\n\nHF_AUTH_TOKEN = os.getenv(\"HF_AUTH_T"
},
{
"path": "api/lib/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "api/lib/prompts.py",
"chars": 2354,
"preview": "from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType\n\n\ndef prepare_prompts(pipeline, mode"
},
{
"path": "api/lib/textual_inversions.py",
"chars": 3124,
"preview": "import json\nimport re\nimport os\nimport asyncio\nfrom utils import Storage\nfrom .vars import MODELS_DIR\n\nlast_textual_inve"
},
{
"path": "api/lib/textual_inversions_test.py",
"chars": 606,
"preview": "import unittest\nfrom .textual_inversions import extract_tokens_from_list\n\n\nclass TextualInversionsTest(unittest.TestCase"
},
{
"path": "api/lib/vars.py",
"chars": 330,
"preview": "import os\n\nRUNTIME_DOWNLOADS = os.getenv(\"RUNTIME_DOWNLOADS\") == \"1\"\nUSE_DREAMBOOTH = os.getenv(\"USE_DREAMBOOTH\") == \"1\""
},
{
"path": "api/loadModel.py",
"chars": 2628,
"preview": "import torch\nimport os\nfrom diffusers import pipelines as _pipelines, AutoPipelineForText2Image\nfrom getScheduler import"
},
{
"path": "api/precision.py",
"chars": 1054,
"preview": "import os\nimport torch\n\nDEPRECATED_PRECISION = os.getenv(\"PRECISION\")\nMODEL_PRECISION = os.getenv(\"MODEL_PRECISION\") or "
},
{
"path": "api/send.py",
"chars": 3190,
"preview": "import json\nimport os\nimport datetime\nimport time\nimport requests\nimport hashlib\nfrom requests_futures.sessions import F"
},
{
"path": "api/server.py",
"chars": 2196,
"preview": "# Do not edit if deploying to Banana Serverless\n# This file is boilerplate for the http server, and follows a strict int"
},
{
"path": "api/status.py",
"chars": 292,
"preview": "class Status:\n def __init__(self):\n self.type = \"init\"\n self.progress = 0.0\n\n def update(self, type,"
},
{
"path": "api/tests.py",
"chars": 978,
"preview": "from test import runTest\n\n\ndef test_memory_free_on_swap_model():\n \"\"\"\n Make sure memory is freed when swapping mod"
},
{
"path": "api/train_dreambooth.py",
"chars": 55303,
"preview": "# Based on https://github.com/huggingface/diffusers/commits/main/examples/dreambooth/train_dreambooth.py\n# Synced to com"
},
{
"path": "api/utils/__init__.py",
"chars": 29,
"preview": "from .storage import Storage\n"
},
{
"path": "api/utils/storage/BaseStorage.py",
"chars": 3071,
"preview": "import os\nimport re\nimport subprocess\nfrom abc import ABC, abstractmethod\nimport xtarfile as tarfile\n\n\nclass BaseArchive"
},
{
"path": "api/utils/storage/BaseStorage_test.py",
"chars": 887,
"preview": "import unittest\nfrom . import Storage, S3Storage, HTTPStorage\n\n\nclass BaseStorageTest(unittest.TestCase):\n def test_g"
},
{
"path": "api/utils/storage/HTTPStorage.py",
"chars": 1736,
"preview": "import re\nimport os\nimport time\nimport requests\nfrom tqdm import tqdm\nfrom .BaseStorage import BaseStorage\nimport urllib"
},
{
"path": "api/utils/storage/S3Storage.py",
"chars": 4436,
"preview": "import boto3\nimport botocore\nimport re\nimport os\nimport time\nfrom tqdm import tqdm\nfrom botocore.client import Config\nfr"
},
{
"path": "api/utils/storage/S3Storage_test.py",
"chars": 1822,
"preview": "import unittest\nimport os\nfrom .S3Storage import S3Storage, AWS_S3_ENDPOINT_URL, AWS_S3_DEFAULT_BUCKET\n\n\nclass S3Storage"
},
{
"path": "api/utils/storage/__init__.py",
"chars": 396,
"preview": "import os\nimport re\nfrom .S3Storage import S3Storage\nfrom .HTTPStorage import HTTPStorage\n\nclasses = [S3Storage, HTTPSto"
},
{
"path": "api/utils/storage/__init__test.py",
"chars": 624,
"preview": "import unittest\nfrom . import Storage, S3Storage, HTTPStorage\n\n\nclass StorageTest(unittest.TestCase):\n def test_url_s"
},
{
"path": "build",
"chars": 385,
"preview": "#!/bin/sh\n\n# This is my common way of building, but you can build however you like.\n# Note if you using a proxy, you nee"
},
{
"path": "docs/internal_safetensor_cache_flow.md",
"chars": 1248,
"preview": "internal document to gather my thoughts\n\nRUNTIME_DOWNLOADS=1 (must be build arg)\nIMAGE_CLOUD_CACHE=\"s3://\" (can be env a"
},
{
"path": "docs/storage.md",
"chars": 1216,
"preview": "# Storage\n\nMost URLs passed at build args or call args support special URLs, both to\nstore and retrieve files.\n\n**The St"
},
{
"path": "install.sh",
"chars": 826,
"preview": "#!/bin/sh\n\n# This entire file is no longer used but kept around for reference.\n\nif [ \"$FLASH_ATTENTION\" == \"1\" ]; then\n\n"
},
{
"path": "package.json",
"chars": 476,
"preview": "{\n \"name\": \"docker-diffusers-api\",\n \"version\": \"0.0.1\",\n \"main\": \"index.js\",\n \"repository\": \"https://github.com/kiri"
},
{
"path": "prime.sh",
"chars": 1865,
"preview": "#!/bin/sh\n\n# need to fix this.\n#download_model {'model_url': 's3://', 'model_id': 'Linaqruf/anything-v3.0', 'model_revis"
},
{
"path": "release.config.js",
"chars": 611,
"preview": "// https://semantic-release.gitbook.io/semantic-release/support/faq#can-i-use-semantic-release-to-publish-non-javascript"
},
{
"path": "requirements.txt",
"chars": 1553,
"preview": "# we pin sanic==22.6.2 for compatibility with banana\nsanic==22.6.2\nsanic-ext==22.6.2\n# earlier sanics don't pin but requ"
},
{
"path": "run.sh",
"chars": 589,
"preview": "#!/bin/bash\n\ndocker run -it --rm \\\n --gpus all \\\n -p 8000:8000 \\\n -e http_proxy=\"http://172.17.0.1:3128\" \\\n -e http"
},
{
"path": "run_integration_tests_on_lambda.sh",
"chars": 5222,
"preview": "#!/bin/bash\n\nPAYLOAD_FILE=\"/tmp/request.json\"\n\nif [ -z \"$LAMBDA_API_KEY\" ]; then\n echo \"No LAMBDA_API_KEY set\"\n exit 1"
},
{
"path": "scripts/devContainerPostCreate.sh",
"chars": 221,
"preview": "#!/bin/bash\n\n# devcontainer.json postCreateCommand\n\necho\necho Initialize conda bindings for bash\nconda init bash\n\necho A"
},
{
"path": "scripts/devContainerServer.sh",
"chars": 137,
"preview": "#!/bin/bash\n\nsource /opt/conda/bin/activate base\n\nln -sf /api/diffusers .\n\nwatchmedo auto-restart --recursive -d api pyt"
},
{
"path": "scripts/patchmatch-setup.sh",
"chars": 511,
"preview": "#!/bin/sh\n\nif [ \"$USE_PATCHMATCH\" != \"1\" ]; then\n echo \"Skipping PyPatchMatch install because USE_PATCHMATCH=$USE_PATCH"
},
{
"path": "scripts/permutations.yaml",
"chars": 840,
"preview": "list:\n\n - name: sd-v1-5\n HF_AUTH_TOKEN: $HF_AUTH_TOKEN\n MODEL_ID: runwayml/stable-diffusion-v1-5\n PIPELINE: AL"
},
{
"path": "scripts/permute.sh",
"chars": 2610,
"preview": "#!/usr/bin/env bash\n\n# Run this in banana-sd-base's PARENT directory.\n# Modify the below first per your preferences\n\n# R"
},
{
"path": "test.py",
"chars": 15306,
"preview": "# This file is used to verify your http server acts as expected\n# Run it with `python3 test.py``\n\nimport requests\nimport"
},
{
"path": "tests/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/integration/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/integration/conftest.py",
"chars": 787,
"preview": "import pytest\nimport os\nfrom .lib import startContainer, get_free_port, DOCKER_GW_IP\n\n\n@pytest.fixture(autouse=True, sco"
},
{
"path": "tests/integration/lib.py",
"chars": 10281,
"preview": "import pytest\nimport docker\nimport atexit\nimport time\nimport boto3\nimport os\nimport requests\nimport socket\nimport asynci"
},
{
"path": "tests/integration/requirements.txt",
"chars": 179,
"preview": "pytest==7.2.0\ndocker==6.0.1\nboto3==1.26.44\nPillow==9.4.0\n# work around breaking changes in urllib3 2.0\n# until https://g"
},
{
"path": "tests/integration/test_attn_procs.py",
"chars": 3168,
"preview": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\nif False:\n\n class TestAttnProcs:\n "
},
{
"path": "tests/integration/test_build_download.py",
"chars": 3791,
"preview": "import sys\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_cloudcache_build_download():\n \"\"\"\n "
},
{
"path": "tests/integration/test_cloud_cache.py",
"chars": 940,
"preview": "import sys\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_cloud_cache_create_and_upload():\n \"\""
},
{
"path": "tests/integration/test_dreambooth.py",
"chars": 3495,
"preview": "import os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\nHF_USERNAME = os.getenv(\"HF_USERNAME\", \"gadicc\")\n\n"
},
{
"path": "tests/integration/test_general.py",
"chars": 1484,
"preview": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\nclass TestGeneralClass:\n \"\"\"\n Ty"
},
{
"path": "tests/integration/test_loras.py",
"chars": 4435,
"preview": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\nclass TestLoRAs:\n def setup_class(s"
},
{
"path": "tests/integration/test_memory.py",
"chars": 1246,
"preview": "import sys\nimport os\nfrom .lib import getMinio, getDDA\nfrom test import runTest\n\n\ndef test_memory():\n \"\"\"\n Make su"
},
{
"path": "touch",
"chars": 0,
"preview": ""
},
{
"path": "update.sh",
"chars": 50,
"preview": "#!/bin/sh\n\nrsync -avzPe \"ssh -p $1\" api/ $2:/api/\n"
}
]
About this extraction
This page contains the full source code of the kiri-art/docker-diffusers-api GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 73 files (261.2 KB), approximately 68.2k tokens, and a symbol index with 157 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.