[
  {
    "path": ".dockerignore",
    "content": "*\n!rembg\n!pyproject.toml\n!poetry.lock\n!README.md\n!.git\n.env\n"
  },
  {
    "path": ".editorconfig",
    "content": "# https://editorconfig.org/\n\nroot = true\n\n[*]\nindent_style = space\nindent_size = 4\ninsert_final_newline = true\ntrim_trailing_whitespace = true\nend_of_line = lf\ncharset = utf-8\n"
  },
  {
    "path": ".gitattributes",
    "content": "rembg/_version.py export-subst\n"
  },
  {
    "path": ".github/FUNDING.yml",
    "content": "github: [danielgatis]\ncustom: [\"https://www.buymeacoffee.com/danielgatis\"]\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: \"[BUG] ...\"\nlabels: bug\nassignees: \"\"\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Images**\nInput images to reproduce.\n\n**OS Version:**\niOS 22\n\n**Rembg version:**\nv2.0.21\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: \"[FEATURE] ...\"\nlabels: enhancement\nassignees: \"\"\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/workflows/close_inactive_issues.yml",
    "content": "name: Close inactive issues\n\non:\n    schedule:\n        - cron: \"30 1 * * *\"\n\njobs:\n    close_inactive_issues:\n        runs-on: ubuntu-latest\n        permissions:\n            issues: write\n            pull-requests: write\n        steps:\n            - uses: actions/stale@v9\n              with:\n                  days-before-issue-stale: 30\n                  days-before-issue-close: 14\n                  stale-issue-label: \"stale\"\n                  stale-issue-message: \"This issue is stale because it has been open for 30 days with no activity.\"\n                  close-issue-message: \"This issue was closed because it has been inactive for 14 days since being marked as stale.\"\n                  days-before-pr-stale: -1\n                  days-before-pr-close: -1\n                  repo-token: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/lint_python.yml",
    "content": "name: Lint\n\non:\n  push:\n    branches:\n      - \"**\"\n  pull_request:\n\njobs:\n    lint_python:\n        runs-on: ubuntu-latest\n        steps:\n            - uses: actions/checkout@v4\n            - uses: actions/setup-python@v5\n            - name: Install Poetry\n              uses: snok/install-poetry@v1\n              with:\n                  virtualenvs-create: true\n                  virtualenvs-in-project: true\n            - name: Install dependencies\n              run: poetry install --with dev --extras \"cpu cli\"\n            - run: poetry run mypy --install-types --non-interactive --ignore-missing-imports ./rembg\n            - run: poetry run bandit --recursive --skip B101,B104,B310,B311,B303,B110 --exclude ./rembg/_version.py ./rembg\n            - run: poetry run black --force-exclude rembg/_version.py --check --diff ./rembg\n            - run: poetry run flake8 ./rembg --count --ignore=B008,C901,E203,E266,E731,F401,F811,F841,W503,E501,E402 --show-source --statistics --exclude ./rembg/_version.py\n            - run: poetry run isort --check-only --profile black ./rembg\n"
  },
  {
    "path": ".github/workflows/publish_docker.yml",
    "content": "name: Publish Docker image\n\non:\n  push:\n    tags:\n      - \"v*.*.*\"\n\njobs:\n  publish_docker:\n    name: Push Docker image to Docker Hub\n    runs-on: ubuntu-24.04\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n\n      - name: Docker meta\n        id: meta\n        uses: docker/metadata-action@v5\n        with:\n          # list of Docker images to use as base name for tags\n          images: |\n            ${{ secrets.DOCKER_HUB_USERNAME }}/rembg\n          # generate Docker tags based on the following events/attributes\n          tags: |\n            type=ref,event=branch\n            type=ref,event=branch\n            type=ref,event=pr\n            type=semver,pattern={{version}}\n            type=semver,pattern={{major}}.{{minor}}\n            type=semver,pattern={{major}}\n            type=sha\n\n      - name: Set up QEMU\n        uses: docker/setup-qemu-action@v3\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKER_HUB_USERNAME }}\n          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}\n\n      - name: Build and push\n        uses: docker/build-push-action@v6\n        with:\n          context: .\n          platforms: linux/amd64\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: ${{ steps.meta.outputs.tags }}\n          labels: ${{ steps.meta.outputs.labels }}\n          cache-from: type=registry,ref=${{ secrets.DOCKER_HUB_USERNAME }}/rembg:buildcache\n          cache-to: type=registry,ref=${{ secrets.DOCKER_HUB_USERNAME }}/rembg:buildcache,mode=max\n"
  },
  {
    "path": ".github/workflows/publish_pypi.yml",
    "content": "name: Publish to Pypi\n\non:\n    push:\n        tags:\n            - \"v*.*.*\"\n\njobs:\n    publish_pypi:\n        runs-on: ubuntu-latest\n        steps:\n            - uses: actions/checkout@v4\n              with:\n                  fetch-depth: 0\n            - uses: actions/setup-python@v5\n            - name: Install Poetry\n              uses: snok/install-poetry@v1\n              with:\n                  virtualenvs-create: true\n                  virtualenvs-in-project: true\n            - name: Install dependencies\n              run: |\n                  poetry self add \"poetry-dynamic-versioning[plugin]\"\n                  poetry install --with dev --extras \"cpu cli\"\n            - name: Build and publish to PyPI\n              run: |\n                  poetry build\n                  poetry publish\n              env:\n                  POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PIPY_PASSWORD }}\n\n    test_install:\n        needs: publish_pypi\n        runs-on: ubuntu-latest\n        strategy:\n            matrix:\n                python-version: [\"3.13\"]\n\n        steps:\n            - uses: actions/checkout@v4\n            - name: Set up Python ${{ matrix.python-version }}\n              uses: actions/setup-python@v5\n              with:\n                  python-version: ${{ matrix.python-version }}\n            - name: Wait for PyPI to update\n              run: sleep 60\n            - name: Install from PyPI\n              run: pip install rembg[cpu,cli]\n            - name: Test installation\n              run: |\n                  attempt=0\n                  until rembg d || [ $attempt -eq 5 ]; do\n                      attempt=$((attempt+1))\n                      echo \"Attempt $attempt to download the models...\"\n                  done\n                  if [ $attempt -eq 5 ]; then\n                      echo \"downloading the models failed 5 times, exiting...\"\n                      exit 1\n                  fi\n        
          rembg --version\n"
  },
  {
    "path": ".github/workflows/windows_installer.yml",
    "content": "name: Build Windows Installer\n\non:\n    push:\n        tags:\n            - \"v*.*.*\"\njobs:\n  windows_installer:\n    name: Build the Inno Setup Installer\n    runs-on: windows-latest\n    steps:\n      - uses: actions/setup-python@v5\n      - uses: actions/checkout@v4\n      - shell: pwsh\n        run: ./_build-exe.ps1\n      - name: Compile CPU Installer\n        uses: Minionguyjpro/Inno-Setup-Action@v1.2.2\n        with:\n          path: _setup-cpu.iss\n          options: /O+\n      - name: Compile GPU Installer\n        uses: Minionguyjpro/Inno-Setup-Action@v1.2.2\n        with:\n          path: _setup-gpu.iss\n          options: /O+\n      - name: Upload CPU installer to release\n        uses: svenstaro/upload-release-action@v2\n        with:\n            repo_token: ${{ secrets.GITHUB_TOKEN }}\n            file: dist/rembg-cli-cpu-installer.exe\n            asset_name: rembg-cli-cpu-installer.exe\n            tag: ${{ github.ref }}\n            overwrite: true\n      - name: Upload GPU installer to release\n        uses: svenstaro/upload-release-action@v2\n        with:\n            repo_token: ${{ secrets.GITHUB_TOKEN }}\n            file: dist/rembg-cli-gpu-installer.exe\n            asset_name: rembg-cli-gpu-installer.exe\n            tag: ${{ github.ref }}\n            overwrite: true\n"
  },
  {
    "path": ".gitignore",
    "content": "# general things to ignore\nbuild/\ndist/\n.venv/\n.direnv/\n*.egg-info/\n*.egg\n*.py[cod]\n__pycache__/\n*.so\n*~\n.env\n.envrc\n.idea\n.pytest_cache\n\n# due to using tox and pytest\n.tox\n.cache\n.mypy_cache\n\n# Poetry\n# For libraries, poetry.lock is often not committed\n# For applications, it should be committed\npoetry.lock\n"
  },
  {
    "path": ".markdownlint.yaml",
    "content": "---\ndefault: true\nMD013: false # line-length\nMD033: false # no-inline-html\n"
  },
  {
    "path": ".python-version",
    "content": "3.13.9\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\ntitle: rembg\nmessage: Rembg is a tool to remove image backgrounds\ntype: software\nauthors:\n  - given-names: Daniel\n    family-names: Gatis\n    email: danielgatis@gmail.com\nidentifiers:\n  - type: url\n    value: 'https://github.com/danielgatis'\nrepository-code: 'https://github.com/danielgatis/rembg'\nurl: 'https://github.com/danielgatis/rembg'\nabstract: Rembg is a tool to remove image backgrounds.\nlicense: MIT\ncommit: 9079508935ae55d6eefa0fd75f870599640e8593\nversion: 2.0.66\ndate-released: '2025-02-21'\n\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM python:3.11-slim\n\nWORKDIR /rembg\n\nRUN pip install --upgrade pip && \\\n    pip install poetry poetry-dynamic-versioning\n\nRUN apt-get update && apt-get install -y curl git && apt-get clean && rm -rf /var/lib/apt/lists/*\n\nCOPY . .\n\nRUN poetry config virtualenvs.create false && \\\n    poetry install --extras \"cpu cli\" --without dev\n\nRUN rembg d u2net\n\nEXPOSE 7000\nENTRYPOINT [\"rembg\"]\nCMD [\"--help\"]\n"
  },
  {
    "path": "Dockerfile_nvidia_cuda_cudnn_gpu",
    "content": "FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04\n\nWORKDIR /rembg\n\nRUN apt-get update && apt-get install -y --no-install-recommends python3-pip python-is-python3 curl && apt-get clean && rm -rf /var/lib/apt/lists/*\n\nCOPY . .\n\nRUN python -m pip install \".[gpu,cli]\" --break-system-packages\nRUN rembg d u2net\n\nEXPOSE 7000\nENTRYPOINT [\"rembg\"]\nCMD [\"--help\"]\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2020 Daniel Gatis\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include LICENSE.txt\ninclude README.md\ninclude pyproject.toml\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img src=\"logo.png\" alt=\"Rembg Logo\" width=\"600\" />\n</p>\n\n<div align=\"center\">\n  <p align=\"center\">Rembg is a tool to remove image backgrounds. It can be used as a CLI, Python library, HTTP server, or Docker container.</p>\n  <div style=\"display: flex; flex-direction: row; justify-content: center; gap: 8px; flex-wrap: wrap; margin-top: 8px;\">\n    <a href=\"https://img.shields.io/badge/License-MIT-blue.svg\"><img src=\"https://img.shields.io/badge/License-MIT-blue.svg\" alt=\"License\" /></a>\n    <a href=\"https://huggingface.co/spaces/KenjieDec/RemBG\"><img src=\"https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-blue\" alt=\"Hugging Face Spaces\" /></a>\n    <a href=\"https://bgremoval.streamlit.app/\"><img src=\"https://img.shields.io/badge/🎈%20Streamlit%20Community-Cloud-blue\" alt=\"Streamlit App\" /></a>\n    <a href=\"https://colab.research.google.com/github/danielgatis/rembg/blob/main/rembg.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open in Colab\" /></a>\n    <a href=\"https://repomapr.com/danielgatis/rembg\"><img src=\"https://img.shields.io/badge/RepoMapr-View_Interactive_Diagram-blue?style=flat&logo=github\" alt=\"RepoMapr\" /></a>\n  </div>\n</div>\n\n<br/>\n\n<p align=\"center\">\n    <a href=\"https://trendshift.io/repositories/2846\" target=\"_blank\">\n        <img src=\"https://trendshift.io/api/badge/repositories/2846\" alt=\"danielgatis%2Frembg | Trendshift\" style=\"width: 250px; height: 55px;\" width=\"250\" height=\"55\"/>\n    </a>\n</p>\n\n## Sponsors\n\n<table>\n <tr>\n    <td align=\"center\" vertical-align=\"center\">\n      <a href=\"https://photoroom.com/api/remove-background?utm_source=rembg&utm_medium=github_webpage&utm_campaign=sponsor\" >\n        <img src=\"https://font-cdn.photoroom.com/media/api-logo.png\" width=\"120px;\" alt=\"Unsplash\" />\n      </a>\n    </td>\n    <td align=\"center\" vertical-align=\"center\">\n      
<b>PhotoRoom Remove Background API</b>\n      <br />\n      <a href=\"https://photoroom.com/api/remove-background?utm_source=rembg&utm_medium=github_webpage&utm_campaign=sponsor\">https://photoroom.com/api</a>\n      <br />\n      <p width=\"200px\">\n        Fast and accurate background remover API<br/>\n      </p>\n    </td>\n  </tr>\n</table>\n\n**If this project has helped you, please consider making a [donation](https://www.buymeacoffee.com/danielgatis).**\n\n## Requirements\n\n```text\npython: >=3.11, <3.14\n```\n\n## Installation\n\nChoose **one** of the following backends based on your hardware:\n\n### CPU support\n\n```bash\npip install \"rembg[cpu]\" # for library\npip install \"rembg[cpu,cli]\" # for library + cli\n```\n\n### GPU support (NVIDIA/CUDA)\n\nFirst, check if your system supports `onnxruntime-gpu` by visiting [onnxruntime.ai](https://onnxruntime.ai/getting-started) and reviewing the installation matrix.\n\n<p style=\"display: flex;align-items: center;justify-content: center;\">\n  <img alt=\"onnxruntime-installation-matrix\" src=\"./onnxruntime-installation-matrix.png\" width=\"400\" />\n</p>\n\nIf your system is compatible, run:\n\n```bash\npip install \"rembg[gpu]\" # for library\npip install \"rembg[gpu,cli]\" # for library + cli\n```\n\n> **Note:** NVIDIA GPUs may require `onnxruntime-gpu`, CUDA, and `cudnn-devel`. See [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689830314) for details. If `rembg[gpu]` doesn't work and you can't install CUDA or `cudnn-devel`, use `rembg[cpu]` with `onnxruntime` instead.\n\n### GPU support (AMD/ROCm)\n\nROCm support requires the `onnxruntime-rocm` package. 
Install it by following [AMD's documentation](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-onnx.html).\n\nOnce `onnxruntime-rocm` is installed and working, install rembg with ROCm support:\n\n```bash\npip install \"rembg[rocm]\" # for library\npip install \"rembg[rocm,cli]\" # for library + cli\n```\n\n## Usage as a CLI\n\nAfter installation, you can use rembg by typing `rembg` in your terminal.\n\nThe `rembg` command has 4 subcommands, one for each input type:\n\n- `i` - single files\n- `p` - folders (batch processing)\n- `s` - HTTP server\n- `b` - RGB24 pixel binary stream\n\nYou can get help about the main command using:\n\n```shell\nrembg --help\n```\n\nYou can also get help for any subcommand:\n\n```shell\nrembg <COMMAND> --help\n```\n\n### rembg `i`\n\nUsed for processing single files.\n\n**Remove background from a remote image:**\n\n```shell\ncurl -s http://input.png | rembg i > output.png\n```\n\n**Remove background from a local file:**\n\n```shell\nrembg i path/to/input.png path/to/output.png\n```\n\n**Specify a model:**\n\n```shell\nrembg i -m u2netp path/to/input.png path/to/output.png\n```\n\n**Return only the mask:**\n\n```shell\nrembg i -om path/to/input.png path/to/output.png\n```\n\n**Apply alpha matting:**\n\n```shell\nrembg i -a path/to/input.png path/to/output.png\n```\n\n**Pass extra parameters (SAM example):**\n\n```shell\nrembg i -m sam -x '{ \"sam_prompt\": [{\"type\": \"point\", \"data\": [724, 740], \"label\": 1}] }' examples/plants-1.jpg examples/plants-1.out.png\n```\n\n**Pass extra parameters (custom model):**\n\n```shell\nrembg i -m u2net_custom -x '{\"model_path\": \"~/.u2net/u2net.onnx\"}' path/to/input.png path/to/output.png\n```\n\n### rembg `p`\n\nUsed for batch processing entire folders.\n\n**Process all images in a folder:**\n\n```shell\nrembg p path/to/input path/to/output\n```\n\n**Watch mode (process new/changed files automatically):**\n\n```shell\nrembg p -w path/to/input 
path/to/output\n```\n\n### rembg `s`\n\nUsed to start an HTTP server.\n\n```shell\nrembg s --host 0.0.0.0 --port 7000 --log_level info\n```\n\nFor complete API documentation, visit: `http://localhost:7000/api`\n\n**Remove background from an image URL:**\n\n```shell\ncurl -s \"http://localhost:7000/api/remove?url=http://input.png\" -o output.png\n```\n\n**Remove background from an uploaded image:**\n\n```shell\ncurl -s -F file=@/path/to/input.jpg \"http://localhost:7000/api/remove\" -o output.png\n```\n\n### rembg `b`\n\nProcess a sequence of RGB24 images from stdin. This is intended to be used with programs like FFmpeg that output RGB24 pixel data to stdout.\n\n```shell\nrembg b <width> <height> -o <output_specifier>\n```\n\n**Arguments:**\n\n| Argument | Description |\n|----------|-------------|\n| `width` | Width of input image(s) |\n| `height` | Height of input image(s) |\n| `output_specifier` | Printf-style specifier for output filenames (e.g., `output-%03u.png` produces `output-000.png`, `output-001.png`, etc.). Omit to write to stdout. |\n\n**Example with FFmpeg:**\n\n```shell\nffmpeg -i input.mp4 -ss 10 -an -f rawvideo -pix_fmt rgb24 pipe:1 | rembg b 1280 720 -o folder/output-%03u.png\n```\n\n> **Note:** The width and height must match FFmpeg's output dimensions. 
The flags `-an -f rawvideo -pix_fmt rgb24 pipe:1` are required for FFmpeg compatibility.\n\n## Usage as a Library\n\n**Input and output as bytes:**\n\n```python\nfrom rembg import remove\n\nwith open('input.png', 'rb') as i:\n    with open('output.png', 'wb') as o:\n        input = i.read()\n        output = remove(input)\n        o.write(output)\n```\n\n**Input and output as a PIL image:**\n\n```python\nfrom rembg import remove\nfrom PIL import Image\n\ninput = Image.open('input.png')\noutput = remove(input)\noutput.save('output.png')\n```\n\n**Input and output as a NumPy array:**\n\n```python\nfrom rembg import remove\nimport cv2\n\ninput = cv2.imread('input.png')\noutput = remove(input)\ncv2.imwrite('output.png', output)\n```\n\n**Force output as bytes:**\n\n```python\nfrom rembg import remove\n\nwith open('input.png', 'rb') as i:\n    with open('output.png', 'wb') as o:\n        input = i.read()\n        output = remove(input, force_return_bytes=True)\n        o.write(output)\n```\n\n**Batch processing with session reuse (recommended for performance):**\n\n```python\nfrom pathlib import Path\nfrom rembg import remove, new_session\n\nsession = new_session()\n\nfor file in Path('path/to/folder').glob('*.png'):\n    input_path = str(file)\n    output_path = str(file.parent / (file.stem + \".out.png\"))\n\n    with open(input_path, 'rb') as i:\n        with open(output_path, 'wb') as o:\n            input = i.read()\n            output = remove(input, session=session)\n            o.write(output)\n```\n\nFor more examples, see the [examples](USAGE.md) page.\n\n## Usage with Docker\n\n### CPU Only\n\nReplace the `rembg` command with `docker run danielgatis/rembg`:\n\n```shell\ndocker run -v .:/data danielgatis/rembg i /data/input.png /data/output.png\n```\n\n### NVIDIA CUDA GPU Acceleration\n\n**Requirements:** Your host must have the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 
installed.\n\nCUDA acceleration requires `cudnn-devel`, so you need to build the Docker image yourself. See [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689914205) for details.\n\n**Build the image:**\n\n```shell\ndocker build -t rembg-nvidia-cuda-cudnn-gpu -f Dockerfile_nvidia_cuda_cudnn_gpu .\n```\n\n> **Note:** This image requires ~11GB of disk space (CPU version is ~1.6GB). Models are not included.\n\n**Run the container:**\n\n```shell\nsudo docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/data rembg-nvidia-cuda-cudnn-gpu i -m birefnet-general /data/input.png /data/output.png\n```\n\n**Tips:**\n\n- You can create your own NVIDIA CUDA image and install `rembg[gpu,cli]` in it.\n- Use `-v /path/to/models/:/root/.u2net` to store model files outside the container, avoiding re-downloads.\n\n## Models\n\nAll models are automatically downloaded and saved to `~/.u2net/` on first use.\n\n### Available Models\n\n- u2net ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for general use cases.\n- u2netp ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A lightweight version of u2net model.\n- u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation.\n- u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. 
Here clothes are parsed into 3 category: Upper body, Lower body and Full body.\n- silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb.\n- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.\n- isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation for anime character.\n- sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use cases.\n- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases.\n- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases.\n- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits.\n- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS).\n- birefnet-hrsod 
([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD).\n- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD).\n- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model with massive dataset.\n- bria-rmbg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx), [source](https://huggingface.co/briaai/RMBG-2.0)): A state-of-the-art background removal model by BRIA AI.\n\n## Environment Variables\n\n| Variable | Description |\n|----------|-------------|\n| `U2NET_HOME` | Path to the directory where models are stored. Defaults to `$XDG_DATA_HOME/.u2net` (or `~/.u2net` if `XDG_DATA_HOME` is not set). |\n| `XDG_DATA_HOME` | Base data directory used when `U2NET_HOME` is not set. Defaults to `~`. |\n| `MODEL_CHECKSUM_DISABLED` | When set (e.g. `MODEL_CHECKSUM_DISABLED=1`), disables hash verification for downloaded models. This is useful if you want to use your own custom/converted model files without rembg re-downloading the originals. |\n| `OMP_NUM_THREADS` | Sets the number of threads used by ONNX Runtime for inference. |\n\n### Using custom model files\n\nIf you need to use a modified version of a model (e.g. converted to a different ONNX IR version for compatibility with an older CUDA toolkit), you can prevent rembg from overwriting it:\n\n1. Set `MODEL_CHECKSUM_DISABLED=1`\n2. Place your custom `.onnx` file in the models directory (`~/.u2net/` by default) with the expected filename (e.g. `u2net.onnx`)\n3. 
Rembg will detect the file exists and use it without re-downloading\n\n## FAQ\n\n### When will this library support Python version 3.xx?\n\nThis library depends on [onnxruntime](https://pypi.org/project/onnxruntime). Python version support is determined by onnxruntime's compatibility.\n\n## Support\n\nIf you find this project useful, consider buying me a coffee (or a beer):\n\n<a href=\"https://www.buymeacoffee.com/danielgatis\" target=\"_blank\"><img src=\"https://bmc-cdn.nyc3.digitaloceanspaces.com/BMC-button-images/custom_images/orange_img.png\" alt=\"Buy Me A Coffee\" style=\"height: auto !important;width: auto !important;\"></a>\n\n## Star History\n\n[![Star History Chart](https://api.star-history.com/svg?repos=danielgatis/rembg&type=Date)](https://star-history.com/#danielgatis/rembg&Date)\n\n## License\n\nCopyright (c) 2020-present [Daniel Gatis](https://github.com/danielgatis)\n\nLicensed under the [MIT License](./LICENSE.txt).\n"
  },
  {
    "path": "USAGE.md",
    "content": "# How to use the remove function\n\n## Load the Image\n\n```python\nfrom PIL import Image\nfrom rembg import new_session, remove\n\ninput_path = 'input.png'\noutput_path = 'output.png'\n\ninput = Image.open(input_path)\n```\n\n## Removing the background\n\n### Without additional arguments\n\nThis defaults to the `u2net` model.\n\n```python\noutput = remove(input)\noutput.save(output_path)\n```\n\n### With a specific model\n\nYou can use the `new_session` function to create a session with a specific model.\n\n```python\nmodel_name = \"isnet-general-use\"\nsession = new_session(model_name)\noutput = remove(input, session=session)\n```\n\n### For processing multiple image files\n\nBy default, `remove` initialises a new session every call. This can be a large bottleneck if you're having to process multiple images. Initialise a session and pass it to the `remove` function for fast multi-image support.\n\n```python\nmodel_name = \"u2net\"\nrembg_session = new_session(model_name)\nfor img in images:\n    output = remove(img, session=rembg_session)\n```\n\n### With alpha matting\n\nAlpha matting is a post processing step that can be used to improve the quality of the output.\n\n```python\noutput = remove(input, alpha_matting=True, alpha_matting_foreground_threshold=270,alpha_matting_background_threshold=20, alpha_matting_erode_size=11)\n```\n\n### Only mask\n\nIf you only want the mask, you can use the `only_mask` argument.\n\n```python\noutput = remove(input, only_mask=True)\n```\n\n### With post processing\n\nYou can use the `post_process_mask` argument to post process the mask to get better results.\n\n```python\noutput = remove(input, post_process_mask=True)\n```\n\n### Replacing the background color\n\nYou can use the `bgcolor` argument to replace the background color.\n\n```python\noutput = remove(input, bgcolor=(255, 255, 255, 255))\n```\n\n### Using input points\n\nYou can use the `input_points` and `input_labels` arguments to specify the points 
that should be used for the masks. This only works with the `sam` model.\n\n```python\nimport numpy as np\n# Define the points and labels\n# The points are defined as [y, x]\ninput_points = np.array([[400, 350], [700, 400], [200, 400]])\ninput_labels = np.array([1, 1, 2])\n\nimage = remove(image,session=session, input_points=input_points, input_labels=input_labels)\n```\n\n## Save the image\n\n```python\noutput.save(output_path)\n```\n"
  },
  {
    "path": "_build-exe.ps1",
    "content": "# Install Poetry if not already installed\nif (-not (Get-Command poetry -ErrorAction SilentlyContinue)) {\n    pip install poetry\n}\n\n# Build CPU version\nWrite-Host \"Building CPU version...\" -ForegroundColor Cyan\npoetry install --extras \"cli cpu\"\npoetry run pip install pyinstaller\npoetry run pyinstaller rembg.spec\nRename-Item -Path \"dist/rembg\" -NewName \"rembg-cpu\"\n\n# Build GPU version\nWrite-Host \"Building GPU version...\" -ForegroundColor Cyan\npoetry install --extras \"cli gpu\"\npoetry run pip install pyinstaller\npoetry run pyinstaller rembg.spec --noconfirm\nRename-Item -Path \"dist/rembg\" -NewName \"rembg-gpu\"\n\nWrite-Host \"Build complete!\" -ForegroundColor Green\nWrite-Host \"CPU version: dist/rembg-cpu\"\nWrite-Host \"GPU version: dist/rembg-gpu\"\n"
  },
  {
    "path": "_modpath.iss",
    "content": "// ----------------------------------------------------------------------------\n//\n// Inno Setup Ver:\t5.4.2\n// Script Version:\t1.4.2\n// Author:\t\t\tJared Breland <jbreland@legroom.net>\n// Homepage:\t\thttp://www.legroom.net/software\n// License:\t\t\tGNU Lesser General Public License (LGPL), version 3\n//\t\t\t\t\t\thttp://www.gnu.org/licenses/lgpl.html\n//\n// Script Function:\n//\tAllow modification of environmental path directly from Inno Setup installers\n//\n// Instructions:\n//\tCopy modpath.iss to the same directory as your setup script\n//\n//\tAdd this statement to your [Setup] section\n//\t\tChangesEnvironment=true\n//\n//\tAdd this statement to your [Tasks] section\n//\tYou can change the Description or Flags\n//\tYou can change the Name, but it must match the ModPathName setting below\n//\t\tName: modifypath; Description: &Add application directory to your environmental path; Flags: unchecked\n//\n//\tAdd the following to the end of your [Code] section\n//\tModPathName defines the name of the task defined above\n//\tModPathType defines whether the 'user' or 'system' path will be modified;\n//\t\tthis will default to user if anything other than system is set\n//\tsetArrayLength must specify the total number of dirs to be added\n//\tResult[0] contains first directory, Result[1] contains second, etc.\n//\t\tconst\n//\t\t\tModPathName = 'modifypath';\n//\t\t\tModPathType = 'user';\n//\n//\t\tfunction ModPathDir(): TArrayOfString;\n//\t\tbegin\n//\t\t\tsetArrayLength(Result, 1);\n//\t\t\tResult[0] := ExpandConstant('{app}');\n//\t\tend;\n//\t\t#include \"modpath.iss\"\n// ----------------------------------------------------------------------------\n\nprocedure ModPath();\nvar\n\toldpath:\tString;\n\tnewpath:\tString;\n\tupdatepath:\tBoolean;\n\tpathArr:\tTArrayOfString;\n\taExecFile:\tString;\n\taExecArr:\tTArrayOfString;\n\ti, d:\t\tInteger;\n\tpathdir:\tTArrayOfString;\n\tregroot:\tInteger;\n\tregpath:\tString;\n\nbegin\n\t// Get 
constants from main script and adjust behavior accordingly\n\t// ModPathType MUST be 'system' or 'user'; force 'user' if invalid\n\tif ModPathType = 'system' then begin\n\t\tregroot := HKEY_LOCAL_MACHINE;\n\t\tregpath := 'SYSTEM\\CurrentControlSet\\Control\\Session Manager\\Environment';\n\tend else begin\n\t\tregroot := HKEY_CURRENT_USER;\n\t\tregpath := 'Environment';\n\tend;\n\n\t// Get array of new directories and act on each individually\n\tpathdir := ModPathDir();\n\tfor d := 0 to GetArrayLength(pathdir)-1 do begin\n\t\tupdatepath := true;\n\n\t\t// Modify WinNT path\n\t\tif UsingWinNT() = true then begin\n\n\t\t\t// Get current path, split into an array\n\t\t\tRegQueryStringValue(regroot, regpath, 'Path', oldpath);\n\t\t\toldpath := oldpath + ';';\n\t\t\ti := 0;\n\n\t\t\twhile (Pos(';', oldpath) > 0) do begin\n\t\t\t\tSetArrayLength(pathArr, i+1);\n\t\t\t\tpathArr[i] := Copy(oldpath, 0, Pos(';', oldpath)-1);\n\t\t\t\toldpath := Copy(oldpath, Pos(';', oldpath)+1, Length(oldpath));\n\t\t\t\ti := i + 1;\n\n\t\t\t\t// Check if current directory matches app dir\n\t\t\t\tif pathdir[d] = pathArr[i-1] then begin\n\t\t\t\t\t// if uninstalling, remove dir from path\n\t\t\t\t\tif IsUninstaller() = true then begin\n\t\t\t\t\t\tcontinue;\n\t\t\t\t\t// if installing, flag that dir already exists in path\n\t\t\t\t\tend else begin\n\t\t\t\t\t\tupdatepath := false;\n\t\t\t\t\tend;\n\t\t\t\tend;\n\n\t\t\t\t// Add current directory to new path\n\t\t\t\tif i = 1 then begin\n\t\t\t\t\tnewpath := pathArr[i-1];\n\t\t\t\tend else begin\n\t\t\t\t\tnewpath := newpath + ';' + pathArr[i-1];\n\t\t\t\tend;\n\t\t\tend;\n\n\t\t\t// Append app dir to path if not already included\n\t\t\tif (IsUninstaller() = false) AND (updatepath = true) then\n\t\t\t\tnewpath := newpath + ';' + pathdir[d];\n\n\t\t\t// Write new path\n\t\t\tRegWriteStringValue(regroot, regpath, 'Path', newpath);\n\n\t\t// Modify Win9x path\n\t\tend else begin\n\n\t\t\t// Convert to shortened dirname\n\t\t\tpathdir[d] := 
GetShortName(pathdir[d]);\n\n\t\t\t// If autoexec.bat exists, check if app dir already exists in path\n\t\t\taExecFile := 'C:\\AUTOEXEC.BAT';\n\t\t\tif FileExists(aExecFile) then begin\n\t\t\t\tLoadStringsFromFile(aExecFile, aExecArr);\n\t\t\t\tfor i := 0 to GetArrayLength(aExecArr)-1 do begin\n\t\t\t\t\tif IsUninstaller() = false then begin\n\t\t\t\t\t\t// If app dir already exists while installing, skip add\n\t\t\t\t\t\tif (Pos(pathdir[d], aExecArr[i]) > 0) then\n\t\t\t\t\t\t\tupdatepath := false;\n\t\t\t\t\t\t\tbreak;\n\t\t\t\t\tend else begin\n\t\t\t\t\t\t// If app dir exists and = what we originally set, then delete at uninstall\n\t\t\t\t\t\tif aExecArr[i] = 'SET PATH=%PATH%;' + pathdir[d] then\n\t\t\t\t\t\t\taExecArr[i] := '';\n\t\t\t\t\tend;\n\t\t\t\tend;\n\t\t\tend;\n\n\t\t\t// If app dir not found, or autoexec.bat didn't exist, then (create and) append to current path\n\t\t\tif (IsUninstaller() = false) AND (updatepath = true) then begin\n\t\t\t\tSaveStringToFile(aExecFile, #13#10 + 'SET PATH=%PATH%;' + pathdir[d], True);\n\n\t\t\t// If uninstalling, write the full autoexec out\n\t\t\tend else begin\n\t\t\t\tSaveStringsToFile(aExecFile, aExecArr, False);\n\t\t\tend;\n\t\tend;\n\tend;\nend;\n\n// Split a string into an array using passed delimiter\nprocedure MPExplode(var Dest: TArrayOfString; Text: String; Separator: String);\nvar\n\ti: Integer;\nbegin\n\ti := 0;\n\trepeat\n\t\tSetArrayLength(Dest, i+1);\n\t\tif Pos(Separator,Text) > 0 then\tbegin\n\t\t\tDest[i] := Copy(Text, 1, Pos(Separator, Text)-1);\n\t\t\tText := Copy(Text, Pos(Separator,Text) + Length(Separator), Length(Text));\n\t\t\ti := i + 1;\n\t\tend else begin\n\t\t\t Dest[i] := Text;\n\t\t\t Text := '';\n\t\tend;\n\tuntil Length(Text)=0;\nend;\n\n\nprocedure CurStepChanged(CurStep: TSetupStep);\nvar\n\ttaskname:\tString;\nbegin\n\ttaskname := ModPathName;\n\tif CurStep = ssPostInstall then\n\t\tif IsTaskSelected(taskname) then\n\t\t\tModPath();\nend;\n\nprocedure 
CurUninstallStepChanged(CurUninstallStep: TUninstallStep);\nvar\n\taSelectedTasks:\tTArrayOfString;\n\ti:\t\t\t\tInteger;\n\ttaskname:\t\tString;\n\tregpath:\t\tString;\n\tregstring:\t\tString;\n\tappid:\t\t\tString;\nbegin\n\t// only run during actual uninstall\n\tif CurUninstallStep = usUninstall then begin\n\t\t// get list of selected tasks saved in registry at install time\n\t\tappid := '{#emit SetupSetting(\"AppId\")}';\n\t\tif appid = '' then appid := '{#emit SetupSetting(\"AppName\")}';\n\t\tregpath := ExpandConstant('Software\\Microsoft\\Windows\\CurrentVersion\\Uninstall\\'+appid+'_is1');\n\t\tRegQueryStringValue(HKLM, regpath, 'Inno Setup: Selected Tasks', regstring);\n\t\tif regstring = '' then RegQueryStringValue(HKCU, regpath, 'Inno Setup: Selected Tasks', regstring);\n\n\t\t// check each task; if matches modpath taskname, trigger patch removal\n\t\tif regstring <> '' then begin\n\t\t\ttaskname := ModPathName;\n\t\t\tMPExplode(aSelectedTasks, regstring, ',');\n\t\t\tif GetArrayLength(aSelectedTasks) > 0 then begin\n\t\t\t\tfor i := 0 to GetArrayLength(aSelectedTasks)-1 do begin\n\t\t\t\t\tif comparetext(aSelectedTasks[i], taskname) = 0 then\n\t\t\t\t\t\tModPath();\n\t\t\t\tend;\n\t\t\tend;\n\t\tend;\n\tend;\nend;\n\nfunction NeedRestart(): Boolean;\nvar\n\ttaskname:\tString;\nbegin\n\ttaskname := ModPathName;\n\tif IsTaskSelected(taskname) and not UsingWinNT() then begin\n\t\tResult := True;\n\tend else begin\n\t\tResult := False;\n\tend;\nend;\n"
  },
  {
    "path": "_setup-cpu.iss",
    "content": "#define MyAppName \"Rembg CPU\"\n#define MyAppVersion \"STABLE\"\n#define MyAppPublisher \"danielgatis\"\n#define MyAppURL \"https://github.com/danielgatis/rembg\"\n#define MyAppExeName \"rembg.exe\"\n#define MyAppId \"49AB7484-212F-4B31-A49F-533A480F3FD4\"\n\n[Setup]\nAppId={#MyAppId}\nAppName={#MyAppName}\nAppVersion={#MyAppVersion}\nAppPublisher={#MyAppPublisher}\nAppPublisherURL={#MyAppURL}\nAppSupportURL={#MyAppURL}\nAppUpdatesURL={#MyAppURL}\nDefaultDirName={autopf}\\Rembg\nDefaultGroupName=Rembg\nDisableProgramGroupPage=yes\nOutputBaseFilename=rembg-cli-cpu-installer\nCompression=lzma\nSolidCompression=yes\nWizardStyle=modern\nOutputDir=dist\nChangesEnvironment=yes\n\n[Languages]\nName: \"english\"; MessagesFile: \"compiler:Default.isl\"\n\n[Files]\nSource: \"{#SourcePath}dist\\rembg-cpu\\{#MyAppExeName}\"; DestDir: \"{app}\"; Flags: ignoreversion\nSource: \"{#SourcePath}dist\\rembg-cpu\\*\"; DestDir: \"{app}\"; Flags: ignoreversion recursesubdirs createallsubdirs\n\n[Tasks]\nName: modifypath; Description: \"Add to PATH variable\"\n\n[Icons]\nName: \"{group}\\Rembg\"; Filename: \"{app}\\{#MyAppExeName}\"\n\n[Code]\nconst\n    ModPathName = 'modifypath';\n    ModPathType = 'user';\n\nfunction ModPathDir(): TArrayOfString;\nbegin\n    setArrayLength(Result, 1)\n    Result[0] := ExpandConstant('{app}');\nend;\n#include \"_modpath.iss\"\n"
  },
  {
    "path": "_setup-gpu.iss",
    "content": "#define MyAppName \"Rembg GPU\"\n#define MyAppVersion \"STABLE\"\n#define MyAppPublisher \"danielgatis\"\n#define MyAppURL \"https://github.com/danielgatis/rembg\"\n#define MyAppExeName \"rembg.exe\"\n#define MyAppId \"49AB7484-212F-4B31-A49F-533A480F3FD4\"\n\n[Setup]\nAppId={#MyAppId}\nAppName={#MyAppName}\nAppVersion={#MyAppVersion}\nAppPublisher={#MyAppPublisher}\nAppPublisherURL={#MyAppURL}\nAppSupportURL={#MyAppURL}\nAppUpdatesURL={#MyAppURL}\nDefaultDirName={autopf}\\Rembg\nDefaultGroupName=Rembg\nDisableProgramGroupPage=yes\nOutputBaseFilename=rembg-cli-gpu-installer\nCompression=lzma\nSolidCompression=yes\nWizardStyle=modern\nOutputDir=dist\nChangesEnvironment=yes\n\n[Languages]\nName: \"english\"; MessagesFile: \"compiler:Default.isl\"\n\n[Files]\nSource: \"{#SourcePath}dist\\rembg-gpu\\{#MyAppExeName}\"; DestDir: \"{app}\"; Flags: ignoreversion\nSource: \"{#SourcePath}dist\\rembg-gpu\\*\"; DestDir: \"{app}\"; Flags: ignoreversion recursesubdirs createallsubdirs\n\n[Tasks]\nName: modifypath; Description: \"Add to PATH variable\"\n\n[Icons]\nName: \"{group}\\Rembg\"; Filename: \"{app}\\{#MyAppExeName}\"\n\n[Code]\nconst\n    ModPathName = 'modifypath';\n    ModPathType = 'user';\n\nfunction ModPathDir(): TArrayOfString;\nbegin\n    setArrayLength(Result, 1)\n    Result[0] := ExpandConstant('{app}');\nend;\n#include \"_modpath.iss\"\n"
  },
  {
    "path": "docker-compose.yml",
    "content": "---\n# You can set variables in .env file in root folder\n#\n# PUBLIC_PORT=7000:7000\n# REPLICAS_COUNT=1\n\nservices:\n  app:\n    build: .\n    command: [\"s\"]\n    deploy:\n      replicas: ${REPLICAS_COUNT:-1}\n    ports:\n      - ${PUBLIC_PORT:-7000:7000}\nversion: '3'\n"
  },
  {
    "path": "man/rembg.1",
    "content": ".TH REMBG 1 \"Januar 2026\" \"2.0.72\" \"User Commands\"\n.SH NAME\nrembg \\- tool to remove background from images\n.SH SYNOPSIS\n.B rembg\n[OPTIONS] COMMAND [ARGS]...\n.SH DESCRIPTION\n.B rembg\nis a tool to remove images background.\n.PP\nIt works as a command line interface and a library.\n.SH OPTIONS\n.TP\n.BR \\-\\-version\nShow the version and exit.\n.TP\n.BR \\-\\-help\nShow this message and exit.\n.SH SEE ALSO\nFull documentation at: <https://github.com/danielgatis/rembg>\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.poetry]\nname = \"rembg\"\nversion = \"0.0.0\"  # Managed by poetry-dynamic-versioning\ndescription = \"Remove image background\"\nauthors = [\"Daniel Gatis <danielgatis@gmail.com>\"]\nlicense = \"MIT\"\nreadme = \"README.md\"\nhomepage = \"https://github.com/danielgatis/rembg\"\nrepository = \"https://github.com/danielgatis/rembg\"\nkeywords = [\"remove\", \"background\", \"u2net\"]\nclassifiers = [\n    \"License :: OSI Approved :: MIT License\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: Scientific/Engineering :: Mathematics\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Software Development\",\n    \"Topic :: Software Development :: Libraries\",\n    \"Topic :: Software Development :: Libraries :: Python Modules\",\n    \"Programming Language :: Python\",\n    \"Programming Language :: Python :: 3 :: Only\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Programming Language :: Python :: 3.13\",\n]\npackages = [{include = \"rembg\"}]\n\n[tool.poetry.dependencies]\npython = \"^3.11\"\njsonschema = \"^4.25.1\"\nnumpy = \"^2.3.0\"\npillow = \"^12.1.0\"\npooch = \"^1.8.2\"\npymatting = \"^1.1.14\"\nscikit-image = \"^0.26.0\"\nscipy = \"^1.16.3\"\ntqdm = \"^4.67.1\"\n\n# CPU backend (optional)\nonnxruntime = {version = \"^1.23.2\", optional = true}\n\n# GPU backend (optional) - only available on Linux/Windows, not on macOS\nonnxruntime-gpu = {version = \"^1.23.2\", optional = true, markers = \"sys_platform != 'darwin'\"}\n\n# ROCm backend (optional) - only available on Linux (latest is 1.22.x)\nonnxruntime-rocm = {version = \"^1.22.0\", optional = true, markers = \"sys_platform == 'linux'\"}\n\n# CLI dependencies (optional)\naiohttp = {version = \"^3.13.2\", optional = true}\nasyncer = {version = \"^0.0.12\", optional = true}\nclick = {version = \"^8.3.1\", optional = true}\nfastapi = {version = \"^0.128.0\", optional = 
true}\nfiletype = {version = \"^1.2.0\", optional = true}\ngradio = {version = \"^6.2.0\", optional = true}\npython-multipart = {version = \"^0.0.21\", optional = true}\nsniffio = {version = \"^1.3.1\", optional = true}\nuvicorn = {version = \"^0.40.0\", optional = true}\nwatchdog = {version = \"^6.0.0\", optional = true}\n\n# Dev dependencies (optional, for pip install .[dev])\nbandit = {version = \"^1.9.2\", optional = true}\nblack = {version = \"^25.12.0\", optional = true}\nflake8 = {version = \"^7.3.0\", optional = true}\nimagehash = {version = \"^4.3.2\", optional = true}\nisort = {version = \"^7.0.0\", optional = true}\nmypy = {version = \"^1.19.1\", optional = true}\npytest = {version = \"^9.0.2\", optional = true}\n\n[tool.poetry.group.dev.dependencies]\nbandit = \"^1.9.2\"\nblack = \"^25.12.0\"\nflake8 = \"^7.3.0\"\nimagehash = \"^4.3.2\"\nisort = \"^7.0.0\"\nmypy = \"^1.19.1\"\npytest = \"^9.0.2\"\n\n[tool.poetry.extras]\ncpu = [\"onnxruntime\"]\ngpu = [\"onnxruntime-gpu\"]\nrocm = [\"onnxruntime-rocm\"]\ncli = [\"aiohttp\", \"asyncer\", \"click\", \"fastapi\", \"filetype\", \"gradio\", \"python-multipart\", \"sniffio\", \"uvicorn\", \"watchdog\"]\ndev = [\"bandit\", \"black\", \"flake8\", \"imagehash\", \"isort\", \"mypy\", \"pytest\"]\n\n[tool.poetry.scripts]\nrembg = \"rembg.cli:main\"\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\", \"poetry-dynamic-versioning>=1.0.0,<2.0.0\"]\nbuild-backend = \"poetry_dynamic_versioning.backend\"\n\n[tool.poetry-dynamic-versioning]\nenable = true\nvcs = \"git\"\nstyle = \"pep440\"\npattern = \"^v(?P<base>\\\\d+\\\\.\\\\d+\\\\.\\\\d+)\"\n\n[tool.poetry-dynamic-versioning.substitution]\nfiles = [\"rembg/__init__.py\"]\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\nfilterwarnings =\n    ignore::DeprecationWarning\n"
  },
  {
    "path": "rembg/__init__.py",
    "content": "try:\n    from importlib.metadata import PackageNotFoundError, version\n\n    try:\n        __version__ = version(\"rembg\")\n    except PackageNotFoundError:\n        __version__ = \"0.0.0\"  # Fallback for development\nexcept ImportError:\n    __version__ = \"0.0.0\"  # Fallback for older Python versions\n\nfrom .bg import remove\nfrom .session_factory import new_session\n"
  },
  {
    "path": "rembg/bg.py",
    "content": "import io\nimport sys\nfrom enum import Enum\nfrom typing import Any, List, Optional, Tuple, Union, cast\n\nimport numpy as np\n\ntry:\n    import onnxruntime as ort  # type: ignore[import-untyped]\nexcept ImportError:\n    print(\"No onnxruntime backend found.\")\n    print(\"Please install rembg with CPU or GPU support:\")\n    print()\n    print('    pip install \"rembg[cpu]\"  # for CPU')\n    print('    pip install \"rembg[gpu]\"  # for NVIDIA/CUDA GPU')\n    print()\n    print(\n        \"For more information, see: https://github.com/danielgatis/rembg#installation\"\n    )\n    sys.exit(1)\n\nfrom PIL import Image, ImageOps\nfrom PIL.Image import Image as PILImage\nfrom pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf\nfrom pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml\nfrom pymatting.util.util import stack_images\nfrom scipy.ndimage import binary_erosion, gaussian_filter\nfrom skimage.morphology import disk, opening\n\nfrom .session_factory import new_session\nfrom .sessions import sessions, sessions_names\nfrom .sessions.base import BaseSession\n\nort.set_default_logger_severity(3)\n\nkernel = disk(1)\n\n\nclass ReturnType(Enum):\n    BYTES = 0\n    PILLOW = 1\n    NDARRAY = 2\n\n\ndef alpha_matting_cutout(\n    img: PILImage,\n    mask: PILImage,\n    foreground_threshold: int,\n    background_threshold: int,\n    erode_structure_size: int,\n) -> PILImage:\n    \"\"\"\n    Perform alpha matting on an image using a given mask and threshold values.\n\n    This function takes a PIL image `img` and a PIL image `mask` as input, along with\n    the `foreground_threshold` and `background_threshold` values used to determine\n    foreground and background pixels. 
The `erode_structure_size` parameter specifies\n    the size of the erosion structure to be applied to the mask.\n\n    The function returns a PIL image representing the cutout of the foreground object\n    from the original image.\n    \"\"\"\n    if img.mode == \"RGBA\" or img.mode == \"CMYK\":\n        img = img.convert(\"RGB\")\n\n    img_array = np.asarray(img)\n    mask_array = np.asarray(mask)\n\n    is_foreground = mask_array > foreground_threshold\n    is_background = mask_array < background_threshold\n\n    structure = None\n    if erode_structure_size > 0:\n        structure = np.ones(\n            (erode_structure_size, erode_structure_size), dtype=np.uint8\n        )\n\n    is_foreground = binary_erosion(is_foreground, structure=structure)\n    is_background = binary_erosion(is_background, structure=structure, border_value=1)\n\n    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)\n    trimap[is_foreground] = 255\n    trimap[is_background] = 0\n\n    img_normalized = img_array / 255.0\n    trimap_normalized = trimap / 255.0\n\n    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)\n    foreground = estimate_foreground_ml(img_normalized, alpha)\n    cutout = stack_images(foreground, alpha)\n\n    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)\n    cutout = Image.fromarray(cutout)\n\n    return cutout\n\n\ndef naive_cutout(img: PILImage, mask: PILImage) -> PILImage:\n    \"\"\"\n    Perform a simple cutout operation on an image using a mask.\n\n    This function takes a PIL image `img` and a PIL image `mask` as input.\n    It uses the mask to create a new image where the pixels from `img` are\n    cut out based on the mask.\n\n    The function returns a PIL image representing the cutout of the original\n    image using the mask.\n    \"\"\"\n    empty = Image.new(\"RGBA\", (img.size), 0)\n    cutout = Image.composite(img, empty, mask)\n    return cutout\n\n\ndef putalpha_cutout(img: PILImage, mask: PILImage) -> 
PILImage:\n    \"\"\"\n    Apply the specified mask to the image as an alpha cutout.\n\n    Args:\n        img (PILImage): The image to be modified.\n        mask (PILImage): The mask to be applied.\n\n    Returns:\n        PILImage: The modified image with the alpha cutout applied.\n    \"\"\"\n    img.putalpha(mask)\n    return img\n\n\ndef get_concat_v_multi(imgs: List[PILImage]) -> PILImage:\n    \"\"\"\n    Concatenate multiple images vertically.\n\n    Args:\n        imgs (List[PILImage]): The list of images to be concatenated.\n\n    Returns:\n        PILImage: The concatenated image.\n    \"\"\"\n    pivot = imgs.pop(0)\n    for im in imgs:\n        pivot = get_concat_v(pivot, im)\n    return pivot\n\n\ndef get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:\n    \"\"\"\n    Concatenate two images vertically.\n\n    Args:\n        img1 (PILImage): The first image.\n        img2 (PILImage): The second image to be concatenated below the first image.\n\n    Returns:\n        PILImage: The concatenated image.\n    \"\"\"\n    dst = Image.new(\"RGBA\", (img1.width, img1.height + img2.height))\n    dst.paste(img1, (0, 0))\n    dst.paste(img2, (0, img1.height))\n    return dst\n\n\ndef post_process(mask: np.ndarray) -> np.ndarray:\n    \"\"\"\n    Post Process the mask for a smooth boundary by applying Morphological Operations\n    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757\n    args:\n        mask: Binary Numpy Mask\n    \"\"\"\n    mask = opening(mask, kernel)\n    mask = gaussian_filter(mask.astype(np.float64), sigma=2)\n    mask = np.where(mask < 127, 0, 255).astype(np.uint8)\n    return mask\n\n\ndef apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:\n    \"\"\"\n    Apply the specified background color to the image.\n\n    Args:\n        img (PILImage): The image to be modified.\n        color (Tuple[int, int, int, int]): The RGBA color to be applied.\n\n    Returns:\n  
      PILImage: The modified image with the background color applied.\n    \"\"\"\n    background = Image.new(\"RGBA\", img.size, tuple(color))\n    colored_image = Image.alpha_composite(background, img)\n\n    return colored_image\n\n\ndef fix_image_orientation(img: PILImage) -> PILImage:\n    \"\"\"\n    Fix the orientation of the image based on its EXIF data.\n\n    Args:\n        img (PILImage): The image to be fixed.\n\n    Returns:\n        PILImage: The fixed image.\n    \"\"\"\n    return cast(PILImage, ImageOps.exif_transpose(img))\n\n\ndef download_models(models: tuple[str, ...]) -> None:\n    \"\"\"\n    Download models for image processing.\n    \"\"\"\n    if len(models) == 0:\n        print(\"No models specified, downloading all models\")\n        models = tuple(sessions_names)\n\n    for model in models:\n        session = sessions.get(model)\n        if session is None:\n            print(f\"Error: no model found: {model}\")\n            sys.exit(1)\n        else:\n            print(f\"Downloading model: {model}\")\n            try:\n                session.download_models()\n            except Exception as e:\n                print(f\"Error downloading model: {e}\")\n\n\ndef remove(\n    data: Union[bytes, PILImage, np.ndarray],\n    alpha_matting: bool = False,\n    alpha_matting_foreground_threshold: int = 240,\n    alpha_matting_background_threshold: int = 10,\n    alpha_matting_erode_size: int = 10,\n    session: Optional[BaseSession] = None,\n    only_mask: bool = False,\n    post_process_mask: bool = False,\n    bgcolor: Optional[Tuple[int, int, int, int]] = None,\n    force_return_bytes: bool = False,\n    *args: Optional[Any],\n    **kwargs: Optional[Any],\n) -> Union[bytes, PILImage, np.ndarray]:\n    \"\"\"\n    Remove the background from an input image.\n\n    This function takes in various parameters and returns a modified version of the input image with the background removed. 
The function can handle input data in the form of bytes, a PIL image, or a numpy array. The function first checks the type of the input data and converts it to a PIL image if necessary. It then fixes the orientation of the image and proceeds to perform background removal using the 'u2net' model. The result is a list of binary masks representing the foreground objects in the image. These masks are post-processed and combined to create a final cutout image. If a background color is provided, it is applied to the cutout image. The function returns the resulting cutout image in the format specified by the input 'return_type' parameter or as python bytes if force_return_bytes is true.\n\n    Parameters:\n        data (Union[bytes, PILImage, np.ndarray]): The input image data.\n        alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.\n        alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.\n        alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.\n        alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.\n        session (Optional[BaseSession], optional): A session object for the 'u2net' model. Defaults to None.\n        only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.\n        post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.\n        bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image. Defaults to None.\n        force_return_bytes (bool, optional): Flag indicating whether to return the cutout image as bytes. 
Defaults to False.\n        *args (Optional[Any]): Additional positional arguments.\n        **kwargs (Optional[Any]): Additional keyword arguments.\n\n    Returns:\n        Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.\n    \"\"\"\n    if isinstance(data, bytes) or force_return_bytes:\n        return_type = ReturnType.BYTES\n        img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))\n    elif isinstance(data, PILImage):\n        return_type = ReturnType.PILLOW\n        img = cast(PILImage, data)\n    elif isinstance(data, np.ndarray):\n        return_type = ReturnType.NDARRAY\n        img = cast(PILImage, Image.fromarray(data))\n    else:\n        raise ValueError(\n            \"Input type {} is not supported. Try using force_return_bytes=True to force python bytes output\".format(\n                type(data)\n            )\n        )\n\n    putalpha = kwargs.pop(\"putalpha\", False)\n\n    # Fix image orientation\n    img = fix_image_orientation(img)\n\n    if session is None:\n        session = new_session(\"u2net\", *args, **kwargs)\n\n    masks = session.predict(img, *args, **kwargs)\n    cutouts = []\n\n    for mask in masks:\n        if post_process_mask:\n            mask = Image.fromarray(post_process(np.array(mask)))\n\n        if only_mask:\n            cutout = mask\n\n        elif alpha_matting:\n            try:\n                cutout = alpha_matting_cutout(\n                    img,\n                    mask,\n                    alpha_matting_foreground_threshold,\n                    alpha_matting_background_threshold,\n                    alpha_matting_erode_size,\n                )\n            except ValueError:\n                if putalpha:\n                    cutout = putalpha_cutout(img, mask)\n                else:\n                    cutout = naive_cutout(img, mask)\n        else:\n            if putalpha:\n                cutout = putalpha_cutout(img, mask)\n            else:\n           
     cutout = naive_cutout(img, mask)\n\n        cutouts.append(cutout)\n\n    cutout = img\n    if len(cutouts) > 0:\n        cutout = get_concat_v_multi(cutouts)\n\n    if bgcolor is not None and not only_mask:\n        cutout = apply_background_color(cutout, bgcolor)\n\n    if ReturnType.PILLOW == return_type:\n        return cutout\n\n    if ReturnType.NDARRAY == return_type:\n        return np.asarray(cutout)\n\n    bio = io.BytesIO()\n    cutout.save(bio, \"PNG\")\n    bio.seek(0)\n\n    return bio.read()\n"
  },
  {
    "path": "rembg/cli.py",
    "content": "import sys\n\n# Fast path for --version (avoid importing heavy dependencies)\nif len(sys.argv) == 2 and sys.argv[1] in (\"--version\", \"-V\"):\n    from importlib.metadata import version\n\n    print(f\"rembg, version {version('rembg')}\")\n    sys.exit(0)\n\ntry:\n    import click\nexcept ImportError:\n    print(\"The CLI dependencies are not installed.\")\n    print(\"Please install rembg with CLI support:\")\n    print()\n    print('    pip install \"rembg[cpu,cli]\"  # for CPU')\n    print('    pip install \"rembg[gpu,cli]\"  # for NVIDIA/CUDA GPU')\n    print()\n    print(\n        \"For more information, see: https://github.com/danielgatis/rembg#installation\"\n    )\n    sys.exit(1)\n\nfrom . import __version__\nfrom .commands import command_functions\n\n\n@click.group()\n@click.version_option(version=__version__)\ndef main() -> None:\n    pass\n\n\nfor command in command_functions:\n    main.add_command(command)\n"
  },
  {
    "path": "rembg/commands/__init__.py",
    "content": "command_functions = []\n\nfrom .b_command import b_command\nfrom .d_command import d_command\nfrom .i_command import i_command\nfrom .p_command import p_command\nfrom .s_command import s_command\n\ncommand_functions.append(b_command)\ncommand_functions.append(d_command)\ncommand_functions.append(i_command)\ncommand_functions.append(p_command)\ncommand_functions.append(s_command)\n"
  },
  {
    "path": "rembg/commands/b_command.py",
    "content": "import asyncio\nimport io\nimport json\nimport os\nimport sys\nfrom typing import IO\n\nimport click\nimport PIL\n\nfrom ..bg import remove\nfrom ..session_factory import new_session\nfrom ..sessions import sessions_names\n\n\n@click.command(  # type: ignore\n    name=\"b\",\n    help=\"for a byte stream as input\",\n)\n@click.option(\n    \"-m\",\n    \"--model\",\n    default=\"u2net\",\n    type=click.Choice(sessions_names),\n    show_default=True,\n    show_choices=True,\n    help=\"model name\",\n)\n@click.option(\n    \"-a\",\n    \"--alpha-matting\",\n    is_flag=True,\n    show_default=True,\n    help=\"use alpha matting\",\n)\n@click.option(\n    \"-af\",\n    \"--alpha-matting-foreground-threshold\",\n    default=240,\n    type=int,\n    show_default=True,\n    help=\"trimap fg threshold\",\n)\n@click.option(\n    \"-ab\",\n    \"--alpha-matting-background-threshold\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"trimap bg threshold\",\n)\n@click.option(\n    \"-ae\",\n    \"--alpha-matting-erode-size\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"erode size\",\n)\n@click.option(\n    \"-om\",\n    \"--only-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"output only the mask\",\n)\n@click.option(\n    \"-ppm\",\n    \"--post-process-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"post process the mask\",\n)\n@click.option(\n    \"-bgc\",\n    \"--bgcolor\",\n    default=(0, 0, 0, 0),\n    type=(int, int, int, int),\n    nargs=4,\n    help=\"Background color (R G B A) to replace the removed background with\",\n)\n@click.option(\"-x\", \"--extras\", type=str)\n@click.option(\n    \"-o\",\n    \"--output_specifier\",\n    type=str,\n    help=\"printf-style specifier for output filenames (e.g. 
'output-%d.png'))\",\n)\n@click.argument(\n    \"image_width\",\n    type=int,\n)\n@click.argument(\n    \"image_height\",\n    type=int,\n)\ndef b_command(\n    model: str,\n    extras: str,\n    image_width: int,\n    image_height: int,\n    output_specifier: str,\n    **kwargs\n) -> None:\n    \"\"\"\n    Command-line interface for processing images by removing the background using a specified model and generating a mask.\n\n    This CLI command takes several options and arguments to configure the background removal process and save the processed images.\n\n    Parameters:\n        model (str): The name of the model to use for background removal.\n        extras (str): Additional options in JSON format that can be passed to customize the background removal process.\n        image_width (int): The width of the input images in pixels.\n        image_height (int): The height of the input images in pixels.\n        output_specifier (str): A printf-style specifier for the output filenames. 
If specified, the processed images will be saved to the specified output directory with filenames generated using the specifier.\n        **kwargs: Additional keyword arguments that can be used to customize the background removal process.\n\n    Returns:\n        None\n    \"\"\"\n    if extras:\n        try:\n            kwargs.update(json.loads(extras))\n        except Exception:\n            raise click.BadParameter(\"extras must be a valid JSON string\")\n\n    session = new_session(model, **kwargs)\n    bytes_per_img = image_width * image_height * 3\n\n    if output_specifier:\n        output_dir = os.path.dirname(\n            os.path.abspath(os.path.expanduser(output_specifier))\n        )\n\n        if not os.path.isdir(output_dir):\n            os.makedirs(output_dir, exist_ok=True)\n\n    def img_to_byte_array(img: PIL.Image.Image) -> bytes:\n        buff = io.BytesIO()\n        img.save(buff, format=\"PNG\")\n        return buff.getvalue()\n\n    async def connect_stdin_stdout():\n        loop = asyncio.get_event_loop()\n        reader = asyncio.StreamReader()\n        protocol = asyncio.StreamReaderProtocol(reader)\n\n        await loop.connect_read_pipe(lambda: protocol, sys.stdin)\n        w_transport, w_protocol = await loop.connect_write_pipe(\n            asyncio.streams.FlowControlMixin, sys.stdout\n        )\n\n        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)\n        return reader, writer\n\n    async def main():\n        reader, writer = await connect_stdin_stdout()\n\n        idx = 0\n        while True:\n            try:\n                img_bytes = await reader.readexactly(bytes_per_img)\n                if not img_bytes:\n                    break\n\n                img = PIL.Image.frombytes(\"RGB\", (image_width, image_height), img_bytes)\n                output = remove(img, session=session, **kwargs)\n\n                if output_specifier:\n                    output.save((output_specifier % idx), 
format=\"PNG\")\n                else:\n                    writer.write(img_to_byte_array(output))\n\n                idx += 1\n            except asyncio.IncompleteReadError:\n                break\n\n    asyncio.run(main())\n"
  },
  {
    "path": "rembg/commands/d_command.py",
    "content": "import click\n\nfrom ..bg import download_models\n\n\n@click.command(  # type: ignore\n    name=\"d\",\n    help=\"download models\",\n)\n@click.argument(\"models\", nargs=-1)\ndef d_command(models: tuple[str, ...]) -> None:\n    \"\"\"\n    Download models\n    \"\"\"\n    download_models(models)\n"
  },
  {
    "path": "rembg/commands/i_command.py",
    "content": "import json\nimport sys\nfrom typing import IO\n\nimport click\n\nfrom ..bg import remove\nfrom ..session_factory import new_session\nfrom ..sessions import sessions_names\n\n\n@click.command(  # type: ignore\n    name=\"i\",\n    help=\"for a file as input\",\n)\n@click.option(\n    \"-m\",\n    \"--model\",\n    default=\"u2net\",\n    type=click.Choice(sessions_names),\n    show_default=True,\n    show_choices=True,\n    help=\"model name\",\n)\n@click.option(\n    \"-a\",\n    \"--alpha-matting\",\n    is_flag=True,\n    show_default=True,\n    help=\"use alpha matting\",\n)\n@click.option(\n    \"-af\",\n    \"--alpha-matting-foreground-threshold\",\n    default=240,\n    type=int,\n    show_default=True,\n    help=\"trimap fg threshold\",\n)\n@click.option(\n    \"-ab\",\n    \"--alpha-matting-background-threshold\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"trimap bg threshold\",\n)\n@click.option(\n    \"-ae\",\n    \"--alpha-matting-erode-size\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"erode size\",\n)\n@click.option(\n    \"-om\",\n    \"--only-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"output only the mask\",\n)\n@click.option(\n    \"-ppm\",\n    \"--post-process-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"post process the mask\",\n)\n@click.option(\n    \"-bgc\",\n    \"--bgcolor\",\n    default=(0, 0, 0, 0),\n    type=(int, int, int, int),\n    nargs=4,\n    help=\"Background color (R G B A) to replace the removed background with\",\n)\n@click.option(\"-x\", \"--extras\", type=str)\n@click.argument(\n    \"input\", default=(None if sys.stdin.isatty() else \"-\"), type=click.File(\"rb\")\n)\n@click.argument(\n    \"output\",\n    default=(None if sys.stdin.isatty() else \"-\"),\n    type=click.File(\"wb\", lazy=True),\n)\ndef i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:\n    \"\"\"\n    Click command line 
interface function to process an input file based on the provided options.\n\n    This function is the entry point for the CLI program. It reads an input file, applies image processing operations based on the provided options, and writes the output to a file.\n\n    Parameters:\n        model (str): The name of the model to use for image processing.\n        extras (str): Additional options in JSON format.\n        input: The input file to process.\n        output: The output file to write the processed image to.\n        **kwargs: Additional keyword arguments corresponding to the command line options.\n\n    Returns:\n        None\n    \"\"\"\n    try:\n        kwargs.update(json.loads(extras))\n    except Exception:\n        pass\n\n    output.write(remove(input.read(), session=new_session(model, **kwargs), **kwargs))\n"
  },
  {
    "path": "rembg/commands/p_command.py",
    "content": "import json\nimport pathlib\nimport time\nfrom typing import cast\n\nimport click\nimport filetype\nfrom tqdm import tqdm\nfrom watchdog.events import FileSystemEvent, FileSystemEventHandler\nfrom watchdog.observers import Observer\n\nfrom ..bg import remove\nfrom ..session_factory import new_session\nfrom ..sessions import sessions_names\n\n\n@click.command(  # type: ignore\n    name=\"p\",\n    help=\"for a folder as input\",\n)\n@click.option(\n    \"-m\",\n    \"--model\",\n    default=\"u2net\",\n    type=click.Choice(sessions_names),\n    show_default=True,\n    show_choices=True,\n    help=\"model name\",\n)\n@click.option(\n    \"-a\",\n    \"--alpha-matting\",\n    is_flag=True,\n    show_default=True,\n    help=\"use alpha matting\",\n)\n@click.option(\n    \"-af\",\n    \"--alpha-matting-foreground-threshold\",\n    default=240,\n    type=int,\n    show_default=True,\n    help=\"trimap fg threshold\",\n)\n@click.option(\n    \"-ab\",\n    \"--alpha-matting-background-threshold\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"trimap bg threshold\",\n)\n@click.option(\n    \"-ae\",\n    \"--alpha-matting-erode-size\",\n    default=10,\n    type=int,\n    show_default=True,\n    help=\"erode size\",\n)\n@click.option(\n    \"-om\",\n    \"--only-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"output only the mask\",\n)\n@click.option(\n    \"-ppm\",\n    \"--post-process-mask\",\n    is_flag=True,\n    show_default=True,\n    help=\"post process the mask\",\n)\n@click.option(\n    \"-w\",\n    \"--watch\",\n    default=False,\n    is_flag=True,\n    show_default=True,\n    help=\"watches a folder for changes\",\n)\n@click.option(\n    \"-d\",\n    \"--delete_input\",\n    default=False,\n    is_flag=True,\n    show_default=True,\n    help=\"delete input file after processing\",\n)\n@click.option(\n    \"-bgc\",\n    \"--bgcolor\",\n    default=(0, 0, 0, 0),\n    type=(int, int, int, int),\n    nargs=4,\n  
  help=\"Background color (R G B A) to replace the removed background with\",\n)\n@click.option(\"-x\", \"--extras\", type=str)\n@click.argument(\n    \"input\",\n    type=click.Path(\n        exists=True,\n        path_type=pathlib.Path,\n        file_okay=False,\n        dir_okay=True,\n        readable=True,\n    ),\n)\n@click.argument(\n    \"output\",\n    type=click.Path(\n        exists=False,\n        path_type=pathlib.Path,\n        file_okay=False,\n        dir_okay=True,\n        writable=True,\n    ),\n)\ndef p_command(\n    model: str,\n    extras: str,\n    input: pathlib.Path,\n    output: pathlib.Path,\n    watch: bool,\n    delete_input: bool,\n    **kwargs,\n) -> None:\n    \"\"\"\n    Command-line interface (CLI) program for performing background removal on images in a folder.\n\n    This program takes a folder as input and uses a specified model to remove the background from the images in the folder.\n    It provides various options for configuration, such as choosing the model, enabling alpha matting, setting trimap thresholds, erode size, etc.\n    Additional options include outputting only the mask and post-processing the mask.\n    The program can also watch the input folder for changes and automatically process new images.\n    The resulting images with the background removed are saved in the specified output folder.\n\n    Parameters:\n        model (str): The name of the model to use for background removal.\n        extras (str): Additional options in JSON format.\n        input (pathlib.Path): The path to the input folder.\n        output (pathlib.Path): The path to the output folder.\n        watch (bool): Whether to watch the input folder for changes.\n        delete_input (bool): Whether to delete the input file after processing.\n        **kwargs: Additional keyword arguments.\n\n    Returns:\n        None\n    \"\"\"\n    try:\n        kwargs.update(json.loads(extras))\n    except Exception:\n        pass\n\n    session = 
new_session(model, **kwargs)\n\n    def process(each_input: pathlib.Path) -> None:\n        try:\n            mimetype = filetype.guess(each_input)\n            if mimetype is None:\n                return\n            if mimetype.mime.find(\"image\") < 0:\n                return\n\n            each_output = (output / each_input.name).with_suffix(\".png\")\n            each_output.parents[0].mkdir(parents=True, exist_ok=True)\n\n            if not each_output.exists():\n                each_output.write_bytes(\n                    cast(\n                        bytes,\n                        remove(each_input.read_bytes(), session=session, **kwargs),\n                    )\n                )\n\n                if watch:\n                    print(\n                        f\"processed: {each_input.absolute()} -> {each_output.absolute()}\"\n                    )\n\n            if delete_input:\n                each_input.unlink()\n\n        except Exception as e:\n            print(e)\n\n    inputs = list(input.glob(\"**/*\"))\n    inputs_tqdm = inputs if watch else tqdm(inputs)\n\n    for each_input in inputs_tqdm:\n        if not each_input.is_dir():\n            process(each_input)\n\n    if watch:\n        should_watch = True\n        observer = Observer()\n\n        class EventHandler(FileSystemEventHandler):\n            def on_any_event(self, event: FileSystemEvent) -> None:\n                src_path = cast(str, event.src_path)\n                if (\n                    not (\n                        event.is_directory or event.event_type in [\"deleted\", \"closed\"]\n                    )\n                    and pathlib.Path(src_path).exists()\n                ):\n                    if src_path.endswith(\"stop.txt\"):\n                        nonlocal should_watch\n                        should_watch = False\n                        pathlib.Path(src_path).unlink()\n                        return\n\n                    process(pathlib.Path(src_path))\n\n  
      event_handler = EventHandler()\n        observer.schedule(event_handler, str(input), recursive=False)\n        observer.start()\n\n        try:\n            while should_watch:\n                time.sleep(1)\n\n        finally:\n            observer.stop()\n            observer.join()\n"
  },
  {
    "path": "rembg/commands/s_command.py",
    "content": "import json\nimport os\nimport webbrowser\nfrom typing import Optional, Tuple, cast\n\nimport aiohttp\nimport click\nimport gradio as gr\nimport uvicorn\nfrom asyncer import asyncify\nfrom fastapi import Depends, FastAPI, File, Form, Query\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom starlette.responses import Response\n\nfrom .. import __version__\nfrom ..bg import remove\nfrom ..session_factory import new_session\nfrom ..sessions import sessions_names\nfrom ..sessions.base import BaseSession\n\n\n@click.command(  # type: ignore\n    name=\"s\",\n    help=\"for a http server\",\n)\n@click.option(\n    \"-p\",\n    \"--port\",\n    default=7000,\n    type=int,\n    show_default=True,\n    help=\"port\",\n)\n@click.option(\n    \"-h\",\n    \"--host\",\n    default=\"0.0.0.0\",\n    type=str,\n    show_default=True,\n    help=\"host\",\n)\n@click.option(\n    \"-l\",\n    \"--log_level\",\n    default=\"info\",\n    type=str,\n    show_default=True,\n    help=\"log level\",\n)\n@click.option(\n    \"-t\",\n    \"--threads\",\n    default=None,\n    type=int,\n    show_default=True,\n    help=\"number of worker threads\",\n)\ndef s_command(port: int, host: str, log_level: str, threads: int) -> None:\n    \"\"\"\n    Command-line interface for running the FastAPI web server.\n\n    This function starts the FastAPI web server with the specified port and log level.\n    If the number of worker threads is specified, it sets the thread limiter accordingly.\n    \"\"\"\n    sessions: dict[str, BaseSession] = {}\n    tags_metadata = [\n        {\n            \"name\": \"Background Removal\",\n            \"description\": \"Endpoints that perform background removal with different image sources.\",\n            \"externalDocs\": {\n                \"description\": \"GitHub Source\",\n                \"url\": \"https://github.com/danielgatis/rembg\",\n            },\n        },\n    ]\n    app = FastAPI(\n        title=\"Rembg\",\n        
description=\"Rembg is a tool to remove images background. That is it.\",\n        version=__version__,\n        contact={\n            \"name\": \"Daniel Gatis\",\n            \"url\": \"https://github.com/danielgatis\",\n            \"email\": \"danielgatis@gmail.com\",\n        },\n        license_info={\n            \"name\": \"MIT License\",\n            \"url\": \"https://github.com/danielgatis/rembg/blob/main/LICENSE.txt\",\n        },\n        openapi_tags=tags_metadata,\n        docs_url=\"/api\",\n    )\n\n    app.add_middleware(\n        CORSMiddleware,\n        allow_credentials=True,\n        allow_origins=[\"*\"],\n        allow_methods=[\"*\"],\n        allow_headers=[\"*\"],\n    )\n\n    class CommonQueryParams:\n        def __init__(\n            self,\n            model: str = Query(\n                description=\"Model to use when processing image\",\n                regex=r\"(\" + \"|\".join(sessions_names) + \")\",\n                default=\"u2net\",\n            ),\n            a: bool = Query(default=False, description=\"Enable Alpha Matting\"),\n            af: int = Query(\n                default=240,\n                ge=0,\n                le=255,\n                description=\"Alpha Matting (Foreground Threshold)\",\n            ),\n            ab: int = Query(\n                default=10,\n                ge=0,\n                le=255,\n                description=\"Alpha Matting (Background Threshold)\",\n            ),\n            ae: int = Query(\n                default=10, ge=0, description=\"Alpha Matting (Erode Structure Size)\"\n            ),\n            om: bool = Query(default=False, description=\"Only Mask\"),\n            ppm: bool = Query(default=False, description=\"Post Process Mask\"),\n            bgc: Optional[str] = Query(default=None, description=\"Background Color\"),\n            extras: Optional[str] = Query(\n                default=None, description=\"Extra parameters as JSON\"\n            ),\n        ):\n  
          self.model = model\n            self.a = a\n            self.af = af\n            self.ab = ab\n            self.ae = ae\n            self.om = om\n            self.ppm = ppm\n            self.extras = extras\n            self.bgc = (\n                cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(\",\"))))\n                if bgc\n                else None\n            )\n\n    class CommonQueryPostParams:\n        def __init__(\n            self,\n            model: str = Form(\n                description=\"Model to use when processing image\",\n                regex=r\"(\" + \"|\".join(sessions_names) + \")\",\n                default=\"u2net\",\n            ),\n            a: bool = Form(default=False, description=\"Enable Alpha Matting\"),\n            af: int = Form(\n                default=240,\n                ge=0,\n                le=255,\n                description=\"Alpha Matting (Foreground Threshold)\",\n            ),\n            ab: int = Form(\n                default=10,\n                ge=0,\n                le=255,\n                description=\"Alpha Matting (Background Threshold)\",\n            ),\n            ae: int = Form(\n                default=10, ge=0, description=\"Alpha Matting (Erode Structure Size)\"\n            ),\n            om: bool = Form(default=False, description=\"Only Mask\"),\n            ppm: bool = Form(default=False, description=\"Post Process Mask\"),\n            bgc: Optional[str] = Query(default=None, description=\"Background Color\"),\n            extras: Optional[str] = Query(\n                default=None, description=\"Extra parameters as JSON\"\n            ),\n        ):\n            self.model = model\n            self.a = a\n            self.af = af\n            self.ab = ab\n            self.ae = ae\n            self.om = om\n            self.ppm = ppm\n            self.extras = extras\n            self.bgc = (\n                cast(Tuple[int, int, int, int], tuple(map(int, 
bgc.split(\",\"))))\n                if bgc\n                else None\n            )\n\n    def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:\n        kwargs = {}\n\n        if commons.extras:\n            try:\n                kwargs.update(json.loads(commons.extras))\n            except Exception:\n                pass\n\n        session = sessions.get(commons.model)\n        if session is None:\n            session = new_session(commons.model, **kwargs)\n            sessions[commons.model] = session\n\n        return Response(\n            remove(\n                content,\n                session=session,\n                alpha_matting=commons.a,\n                alpha_matting_foreground_threshold=commons.af,\n                alpha_matting_background_threshold=commons.ab,\n                alpha_matting_erode_size=commons.ae,\n                only_mask=commons.om,\n                post_process_mask=commons.ppm,\n                bgcolor=commons.bgc,\n                **kwargs,\n            ),\n            media_type=\"image/png\",\n        )\n\n    @app.on_event(\"startup\")\n    def startup():\n        try:\n            webbrowser.open(f\"http://localhost:{port}\")\n        except Exception:\n            pass\n\n        if threads is not None:\n            from anyio import CapacityLimiter\n            from anyio.lowlevel import RunVar\n\n            RunVar(\"_default_thread_limiter\").set(CapacityLimiter(threads))\n\n    @app.get(\n        path=\"/api/remove\",\n        tags=[\"Background Removal\"],\n        summary=\"Remove from URL\",\n        description=\"Removes the background from an image obtained by retrieving an URL.\",\n    )\n    async def get_index(\n        url: str = Query(\n            default=..., description=\"URL of the image that has to be processed.\"\n        ),\n        commons: CommonQueryParams = Depends(),\n    ):\n        async with aiohttp.ClientSession() as session:\n            async with session.get(url) 
as response:\n                file = await response.read()\n                return await asyncify(im_without_bg)(file, commons)\n\n    @app.post(\n        path=\"/api/remove\",\n        tags=[\"Background Removal\"],\n        summary=\"Remove from Stream\",\n        description=\"Removes the background from an image sent within the request itself.\",\n    )\n    async def post_index(\n        file: bytes = File(\n            default=...,\n            description=\"Image file (byte stream) that has to be processed.\",\n        ),\n        commons: CommonQueryPostParams = Depends(),\n    ):\n        return await asyncify(im_without_bg)(file, commons)  # type: ignore\n\n    def gr_app(app):\n        def inference(input_path, model, *args):\n            output_path = \"output.png\"\n            a, af, ab, ae, om, ppm, cmd_args = args\n\n            kwargs = {\n                \"alpha_matting\": a,\n                \"alpha_matting_foreground_threshold\": af,\n                \"alpha_matting_background_threshold\": ab,\n                \"alpha_matting_erode_size\": ae,\n                \"only_mask\": om,\n                \"post_process_mask\": ppm,\n            }\n\n            if cmd_args:\n                kwargs.update(json.loads(cmd_args))\n\n            session = sessions.get(model)\n            if session is None:\n                session = new_session(model, **kwargs)\n                sessions[model] = session\n            kwargs[\"session\"] = session\n\n            with open(input_path, \"rb\") as i:\n                with open(output_path, \"wb\") as o:\n                    input = i.read()\n                    output = remove(input, **kwargs)\n                    o.write(output)\n            return os.path.join(output_path)\n\n        interface = gr.Interface(\n            inference,\n            [\n                gr.components.Image(type=\"filepath\", label=\"Input\"),\n                gr.components.Dropdown(sessions_names, value=\"u2net\", 
label=\"Models\"),\n                gr.components.Checkbox(value=True, label=\"Alpha matting\"),\n                gr.components.Slider(\n                    value=240, minimum=0, maximum=255, label=\"Foreground threshold\"\n                ),\n                gr.components.Slider(\n                    value=10, minimum=0, maximum=255, label=\"Background threshold\"\n                ),\n                gr.components.Slider(\n                    value=40, minimum=0, maximum=255, label=\"Erosion size\"\n                ),\n                gr.components.Checkbox(value=False, label=\"Only mask\"),\n                gr.components.Checkbox(value=True, label=\"Post process mask\"),\n                gr.components.Textbox(label=\"Arguments\"),\n            ],\n            gr.components.Image(type=\"filepath\", label=\"Output\"),\n            concurrency_limit=3,\n            analytics_enabled=False,\n        )\n\n        app = gr.mount_gradio_app(app, interface, path=\"/\")\n        return app\n\n    print(\n        f\"To access the API documentation, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}/api\"\n    )\n    print(\n        f\"To access the UI, go to http://{'localhost' if host == '0.0.0.0' else host}:{port}\"\n    )\n\n    uvicorn.run(gr_app(app), host=host, port=port, log_level=log_level)\n"
  },
  {
    "path": "rembg/session_factory.py",
    "content": "import os\nfrom typing import Optional, Type\n\nimport onnxruntime as ort\n\nfrom .sessions import sessions_class\nfrom .sessions.base import BaseSession\nfrom .sessions.u2net import U2netSession\n\n\ndef new_session(model_name: str = \"u2net\", *args, **kwargs) -> BaseSession:\n    \"\"\"\n    Create a new session object based on the specified model name.\n\n    This function searches for the session class based on the model name in the 'sessions_class' list.\n    It then creates an instance of the session class with the provided arguments.\n    The 'sess_opts' object is created using the 'ort.SessionOptions()' constructor.\n    If the 'OMP_NUM_THREADS' environment variable is set, both the 'inter_op_num_threads' and 'intra_op_num_threads' options of 'sess_opts' are set to its integer value.\n\n    Parameters:\n        model_name (str): The name of the model.\n        *args: Additional positional arguments.\n        **kwargs: Additional keyword arguments.\n\n    Raises:\n        ValueError: If no session class with the given `model_name` is found.\n\n    Returns:\n        BaseSession: The created session object.\n    \"\"\"\n    session_class: Optional[Type[BaseSession]] = None\n\n    for sc in sessions_class:\n        if sc.name() == model_name:\n            session_class = sc\n            break\n\n    if session_class is None:\n        raise ValueError(f\"No session class found for model '{model_name}'\")\n\n    sess_opts = ort.SessionOptions()\n\n    if \"OMP_NUM_THREADS\" in os.environ:\n        threads = int(os.environ[\"OMP_NUM_THREADS\"])\n        sess_opts.inter_op_num_threads = threads\n        sess_opts.intra_op_num_threads = threads\n\n    return session_class(model_name, sess_opts, *args, **kwargs)\n"
  },
  {
    "path": "rembg/sessions/__init__.py",
    "content": "from __future__ import annotations\n\nfrom typing import Dict, List\n\nfrom .base import BaseSession\n\nsessions: Dict[str, type[BaseSession]] = {}\n\nfrom .birefnet_general import BiRefNetSessionGeneral\n\nsessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral\n\nfrom .birefnet_general_lite import BiRefNetSessionGeneralLite\n\nsessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite\n\nfrom .birefnet_portrait import BiRefNetSessionPortrait\n\nsessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait\n\nfrom .birefnet_dis import BiRefNetSessionDIS\n\nsessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS\n\nfrom .birefnet_hrsod import BiRefNetSessionHRSOD\n\nsessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD\n\nfrom .birefnet_cod import BiRefNetSessionCOD\n\nsessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD\n\nfrom .birefnet_massive import BiRefNetSessionMassive\n\nsessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive\n\nfrom .dis_anime import DisSession\n\nsessions[DisSession.name()] = DisSession\n\nfrom .dis_custom import DisCustomSession\n\nsessions[DisCustomSession.name()] = DisCustomSession\n\nfrom .dis_general_use import DisSession as DisSessionGeneralUse\n\nsessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse\n\nfrom .sam import SamSession\n\nsessions[SamSession.name()] = SamSession\n\nfrom .silueta import SiluetaSession\n\nsessions[SiluetaSession.name()] = SiluetaSession\n\nfrom .u2net_cloth_seg import Unet2ClothSession\n\nsessions[Unet2ClothSession.name()] = Unet2ClothSession\n\nfrom .u2net_custom import U2netCustomSession\n\nsessions[U2netCustomSession.name()] = U2netCustomSession\n\nfrom .u2net_human_seg import U2netHumanSegSession\n\nsessions[U2netHumanSegSession.name()] = U2netHumanSegSession\n\nfrom .u2net import U2netSession\n\nsessions[U2netSession.name()] = U2netSession\n\nfrom .u2netp import U2netpSession\n\nsessions[U2netpSession.name()] = 
U2netpSession\n\nfrom .bria_rmbg import BriaRmBgSession\n\nsessions[BriaRmBgSession.name()] = BriaRmBgSession\n\nfrom .ben_custom import BenCustomSession\n\nsessions[BenCustomSession.name()] = BenCustomSession\n\nsessions_names = list(sessions.keys())\nsessions_class = list(sessions.values())\n"
  },
  {
    "path": "rembg/sessions/base.py",
    "content": "import os\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\nimport onnxruntime as ort\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\n\nclass BaseSession:\n    \"\"\"This is a base class for managing a session with a machine learning model.\"\"\"\n\n    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):\n        \"\"\"Initialize an instance of the BaseSession class.\"\"\"\n        self.model_name = model_name\n\n        if \"providers\" in kwargs and isinstance(kwargs[\"providers\"], list):\n            providers = kwargs.pop(\"providers\")\n        else:\n            device_type = ort.get_device()\n            if (\n                device_type == \"GPU\"\n                and \"CUDAExecutionProvider\" in ort.get_available_providers()\n            ):\n                providers = [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n            elif (\n                device_type[0:3] == \"GPU\"\n                and \"ROCMExecutionProvider\" in ort.get_available_providers()\n            ):\n                providers = [\"ROCMExecutionProvider\", \"CPUExecutionProvider\"]\n            else:\n                providers = [\"CPUExecutionProvider\"]\n\n        self.inner_session = ort.InferenceSession(\n            str(self.__class__.download_models(*args, **kwargs)),\n            sess_options=sess_opts,\n            providers=providers,\n        )\n\n    def normalize(\n        self,\n        img: PILImage,\n        mean: Tuple[float, float, float],\n        std: Tuple[float, float, float],\n        size: Tuple[int, int],\n        *args,\n        **kwargs\n    ) -> Dict[str, np.ndarray]:\n        im = img.convert(\"RGB\").resize(size, Image.Resampling.LANCZOS)\n\n        im_ary = np.array(im)\n        im_ary = im_ary / max(np.max(im_ary), 1e-6)\n\n        tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))\n        tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]\n        
tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]\n        tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]\n\n        tmpImg = tmpImg.transpose((2, 0, 1))\n\n        return {\n            self.inner_session.get_inputs()[0]\n            .name: np.expand_dims(tmpImg, 0)\n            .astype(np.float32)\n        }\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        raise NotImplementedError\n\n    @classmethod\n    def checksum_disabled(cls, *args, **kwargs):\n        return os.getenv(\"MODEL_CHECKSUM_DISABLED\", None) is not None\n\n    @classmethod\n    def u2net_home(cls, *args, **kwargs):\n        return os.path.expanduser(\n            os.getenv(\n                \"U2NET_HOME\", os.path.join(os.getenv(\"XDG_DATA_HOME\", \"~\"), \".u2net\")\n            )\n        )\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        raise NotImplementedError\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        raise NotImplementedError\n"
  },
  {
    "path": "rembg/sessions/ben_custom.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport onnxruntime as ort\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass BenCustomSession(BaseSession):\n    \"\"\"This is a class representing a custom session for the Ben model.\"\"\"\n\n    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):\n        \"\"\"\n        Initialize a new BenCustomSession object.\n\n        Parameters:\n            model_name (str): The name of the model.\n            sess_opts: The session options.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        super().__init__(model_name, sess_opts, *args, **kwargs)\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the mask image for the input image.\n\n        This method takes a PILImage object as input and returns a list of PILImage objects as output. 
It performs several image processing operations to generate the mask image.\n\n        Parameters:\n            img (PILImage): The input image.\n\n        Returns:\n            List[PILImage]: A list of PILImage objects representing the generated mask image.\n        \"\"\"\n\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Download the model files.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The absolute path to the model files.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        return os.path.abspath(os.path.expanduser(model_path))\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Get the name of the model.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"ben_custom\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_cod.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionCOD(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-COD model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:f6d0d21ca89d287f17e7afe9f5fd3b45\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-COD session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-cod\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_dis.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionDIS(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-DIS model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:2d4d44102b446f33a4ebb2e56c051f2b\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-DIS session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-dis\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_general.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass BiRefNetSessionGeneral(BaseSession):\n    \"\"\"\n    This class represents a BiRefNet-General session, which is a subclass of BaseSession.\n    \"\"\"\n\n    def sigmoid(self, mat):\n        return 1 / (1 + np.exp(-mat))\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the output masks for the input image using the inner session.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: The list of output masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)\n            ),\n        )\n\n        pred = self.sigmoid(ort_outs[0][:, 0, :, :])\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-General model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx\",\n       
     (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:7a35a0141cbbc80de11d9c9a28f52697\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-General session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-general\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_general_lite.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-General-Lite model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:4fab47adc4ff364be1713e97b7e66334\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-General-Lite session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-general-lite\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_hrsod.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionHRSOD(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-HRSOD model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:c017ade5de8a50ff0fd74d790d268dda\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-HRSOD session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-hrsod\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_massive.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionMassive(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-Massive session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-Massive model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:33e726a2136a3d59eb0fdf613e31e3e9\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-Massive session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-massive\"\n"
  },
  {
    "path": "rembg/sessions/birefnet_portrait.py",
    "content": "import os\n\nimport pooch\n\nfrom . import BiRefNetSessionGeneral\n\n\nclass BiRefNetSessionPortrait(BiRefNetSessionGeneral):\n    \"\"\"\n    This class represents a BiRefNet-Portrait session, which is a subclass of BiRefNetSessionGeneral.\n    \"\"\"\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BiRefNet-Portrait model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:c3a64a6abf20250d090cd055f12a3b67\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the BiRefNet-Portrait session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"birefnet-portrait\"\n"
  },
  {
    "path": "rembg/sessions/bria_rmbg.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass BriaRmBgSession(BaseSession):\n    \"\"\"\n    This class represents a Bria-rmbg-2.0 session, which is a subclass of BaseSession.\n    \"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the output masks for the input image using the inner session.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: The list of output masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the BRIA-RMBG 2.0 model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else 
\"sha256:5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the Bria-rmbg session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"bria-rmbg\"\n"
  },
  {
    "path": "rembg/sessions/dis_anime.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass DisSession(BaseSession):\n    \"\"\"\n    This class represents a session for object detection.\n    \"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Use a pre-trained model to predict the object in the given image.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of predicted mask images.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Download the pre-trained models.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The path of the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:6f184e756bb3bd901c8849220a83e38e\"\n            ),\n            fname=fname,\n            
path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Get the name of the pre-trained model.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The name of the pre-trained model.\n        \"\"\"\n        return \"isnet-anime\"\n"
  },
  {
    "path": "rembg/sessions/dis_custom.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport onnxruntime as ort\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass DisCustomSession(BaseSession):\n    \"\"\"This is a class representing a custom session for the Dis model.\"\"\"\n\n    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):\n        \"\"\"\n        Initialize a new DisCustomSession object.\n\n        Parameters:\n            model_name (str): The name of the model.\n            sess_opts: The session options.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        super().__init__(model_name, sess_opts, *args, **kwargs)\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the mask image for the input image.\n\n        This method takes a PILImage object as input and returns a list of PILImage objects as output. 
It performs several image processing operations to generate the mask image.\n\n        Parameters:\n            img (PILImage): The input image.\n\n        Returns:\n            List[PILImage]: A list of PILImage objects representing the generated mask image.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Resolve the path of the user-supplied model file; nothing is downloaded.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments; 'model_path' is required.\n\n        Returns:\n            str: The absolute path to the model file, with '~' expanded.\n\n        Raises:\n            ValueError: If 'model_path' is not provided.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        return os.path.abspath(os.path.expanduser(model_path))\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Get the name of the model.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"dis_custom\"\n"
  },
  {
    "path": "rembg/sessions/dis_general_use.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass DisSession(BaseSession):\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the mask image for the input image.\n\n        This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.\n\n        Parameters:\n            img (PILImage): The input image.\n\n        Returns:\n            List[PILImage]: A list of PILImage objects representing the generated mask image.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the pre-trained model file.\n\n        This class method downloads the pre-trained model file from a specified URL using the pooch library.\n\n        Parameters:\n            args: Additional positional arguments.\n            kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                
else \"md5:fc16ebd8b0c10d971d3513d564d01e29\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the model.\n\n        This class method returns the name of the model.\n\n        Parameters:\n            args: Additional positional arguments.\n            kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"isnet-general-use\"\n"
  },
  {
    "path": "rembg/sessions/sam.py",
    "content": "import os\nfrom copy import deepcopy\nfrom typing import List\n\nimport numpy as np\nimport onnxruntime as ort\nimport pooch\nfrom jsonschema import validate\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\nfrom scipy.ndimage import map_coordinates\n\nfrom .base import BaseSession\n\n\ndef warp_affine(\n    image: np.ndarray, matrix: np.ndarray, output_shape: tuple\n) -> np.ndarray:\n    \"\"\"\n    Apply affine transformation to an image (matching cv2.warpAffine behavior).\n\n    cv2.warpAffine maps source coordinates to destination coordinates:\n        dst(M @ [x, y, 1]^T) = src(x, y)\n\n    So to fill dst(x', y'), we compute the inverse:\n        src_coords = M^(-1) @ [x', y', 1]^T\n\n    Args:\n        image: Input image (H, W) or (H, W, C)\n        matrix: 2x3 affine transformation matrix\n        output_shape: (height, width) of output\n\n    Returns:\n        Transformed image\n    \"\"\"\n    h, w = output_shape\n\n    # Build full 3x3 matrix and compute inverse\n    M_full = np.vstack([matrix, [0, 0, 1]])\n    M_inv = np.linalg.inv(M_full)[:2]\n\n    # Create output coordinate grid\n    cols = np.arange(w)\n    rows = np.arange(h)\n    x_coords, y_coords = np.meshgrid(cols, rows)\n\n    # Apply inverse transform to get source coordinates\n    src_x = M_inv[0, 0] * x_coords + M_inv[0, 1] * y_coords + M_inv[0, 2]\n    src_y = M_inv[1, 0] * x_coords + M_inv[1, 1] * y_coords + M_inv[1, 2]\n\n    if image.ndim == 2:\n        result = map_coordinates(\n            image.astype(np.float64), [src_y, src_x], order=1, mode=\"constant\", cval=0\n        )\n    else:\n        result = np.zeros((h, w, image.shape[2]), dtype=np.float64)\n        for c in range(image.shape[2]):\n            result[:, :, c] = map_coordinates(\n                image[:, :, c].astype(np.float64),\n                [src_y, src_x],\n                order=1,\n                mode=\"constant\",\n                cval=0,\n            )\n\n    return 
result.astype(image.dtype)\n\n\ndef get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):\n    scale = long_side_length * 1.0 / max(oldh, oldw)\n    newh, neww = oldh * scale, oldw * scale\n    neww = int(neww + 0.5)\n    newh = int(newh + 0.5)\n\n    return (newh, neww)\n\n\ndef apply_coords(coords: np.ndarray, original_size, target_length):\n    old_h, old_w = original_size\n    new_h, new_w = get_preprocess_shape(\n        original_size[0], original_size[1], target_length\n    )\n\n    coords = deepcopy(coords).astype(float)\n    coords[..., 0] = coords[..., 0] * (new_w / old_w)\n    coords[..., 1] = coords[..., 1] * (new_h / old_h)\n\n    return coords\n\n\ndef get_input_points(prompt):\n    points = []\n    labels = []\n\n    for mark in prompt:\n        if mark[\"type\"] == \"point\":\n            points.append(mark[\"data\"])\n            labels.append(mark[\"label\"])\n        elif mark[\"type\"] == \"rectangle\":\n            points.append([mark[\"data\"][0], mark[\"data\"][1]])\n            points.append([mark[\"data\"][2], mark[\"data\"][3]])\n            labels.append(2)\n            labels.append(3)\n\n    points, labels = np.array(points), np.array(labels)\n    return points, labels\n\n\ndef transform_masks(masks, original_size, transform_matrix):\n    output_masks = []\n\n    for batch in range(masks.shape[0]):\n        batch_masks = []\n        for mask_id in range(masks.shape[1]):\n            mask = masks[batch, mask_id]\n            mask = warp_affine(\n                mask,\n                transform_matrix[:2],\n                (original_size[0], original_size[1]),\n            )\n            batch_masks.append(mask)\n        output_masks.append(batch_masks)\n\n    return np.array(output_masks)\n\n\nclass SamSession(BaseSession):\n    \"\"\"\n    This class represents a session for the Sam model.\n\n    Args:\n        model_name (str): The name of the model.\n        sess_opts (ort.SessionOptions): The session options.\n        
*args: Variable length argument list.\n        **kwargs: Arbitrary keyword arguments.\n    \"\"\"\n\n    def __init__(\n        self,\n        model_name: str,\n        sess_opts: ort.SessionOptions,\n        *args,\n        **kwargs,\n    ):\n        \"\"\"\n        Initialize a new SamSession with the given model name and session options.\n\n        Args:\n            model_name (str): The name of the model.\n            sess_opts (ort.SessionOptions): The session options.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n        \"\"\"\n        self.model_name = model_name\n\n        paths = self.__class__.download_models(*args, **kwargs)\n        self.encoder = ort.InferenceSession(\n            str(paths[0]),\n            sess_options=sess_opts,\n        )\n        self.decoder = ort.InferenceSession(\n            str(paths[1]),\n            sess_options=sess_opts,\n        )\n\n    def predict(\n        self,\n        img: PILImage,\n        *args,\n        **kwargs,\n    ) -> List[PILImage]:\n        \"\"\"\n        Predict masks for an input image.\n\n        This function takes an image as input and performs various preprocessing steps on the image. It then runs the image through an encoder to obtain an image embedding. The function also takes input labels and points as additional arguments. It concatenates the input points and labels with padding and transforms them. It creates an empty mask input and an indicator for no mask. The function then passes the image embedding, point coordinates, point labels, mask input, and has mask input to a decoder. 
The decoder generates masks based on the input and returns them as a list of images.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of masks generated by the decoder.\n        \"\"\"\n        prompt = kwargs.get(\n            \"sam_prompt\",\n            [\n                {\n                    \"type\": \"point\",\n                    \"label\": 1,\n                    \"data\": [int(img.width / 2), int(img.height / 2)],\n                }\n            ],\n        )\n        schema = {\n            \"type\": \"array\",\n            \"items\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"type\": {\"type\": \"string\"},\n                    \"label\": {\"type\": \"integer\"},\n                    \"data\": {\n                        \"type\": \"array\",\n                        \"items\": {\"type\": \"number\"},\n                    },\n                },\n            },\n        }\n\n        validate(instance=prompt, schema=schema)\n\n        target_size = 1024\n        input_size = (684, 1024)\n        encoder_input_name = self.encoder.get_inputs()[0].name\n\n        img = img.convert(\"RGB\")\n        cv_image = np.array(img)\n        original_size = cv_image.shape[:2]\n\n        scale_x = input_size[1] / cv_image.shape[1]\n        scale_y = input_size[0] / cv_image.shape[0]\n        scale = min(scale_x, scale_y)\n\n        transform_matrix = np.array(\n            [\n                [scale, 0, 0],\n                [0, scale, 0],\n                [0, 0, 1],\n            ]\n        )\n\n        cv_image = warp_affine(\n            cv_image,\n            transform_matrix[:2],\n            (input_size[0], input_size[1]),\n        )\n\n        ## encoder\n\n        encoder_inputs = {\n            encoder_input_name: 
cv_image.astype(np.float32),\n        }\n\n        encoder_output = self.encoder.run(None, encoder_inputs)\n        image_embedding = encoder_output[0]\n\n        embedding = {\n            \"image_embedding\": image_embedding,\n            \"original_size\": original_size,\n            \"transform_matrix\": transform_matrix,\n        }\n\n        ## decoder\n\n        input_points, input_labels = get_input_points(prompt)\n        onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[\n            None, :, :\n        ]\n        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[\n            None, :\n        ].astype(np.float32)\n        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(\n            np.float32\n        )\n\n        onnx_coord = np.concatenate(\n            [\n                onnx_coord,\n                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),\n            ],\n            axis=2,\n        )\n        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)\n        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)\n\n        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n        onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n\n        decoder_inputs = {\n            \"image_embeddings\": image_embedding,\n            \"point_coords\": onnx_coord,\n            \"point_labels\": onnx_label,\n            \"mask_input\": onnx_mask_input,\n            \"has_mask_input\": onnx_has_mask_input,\n            \"orig_im_size\": np.array(input_size, dtype=np.float32),\n        }\n\n        masks, _, _ = self.decoder.run(None, decoder_inputs)\n        inv_transform_matrix = np.linalg.inv(transform_matrix)\n        masks = transform_masks(masks, original_size, inv_transform_matrix)\n\n        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)\n        for m in masks[0, :, :, :]:\n            mask[m > 0.0] = [255, 255, 255]\n\n        return 
[Image.fromarray(mask).convert(\"L\")]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Class method to download ONNX model files.\n\n        This method is responsible for downloading two ONNX model files from specified URLs and saving them locally. The downloaded files are saved with the naming convention 'name_encoder.onnx' and 'name_decoder.onnx', where 'name' is the value returned by the 'name' method.\n\n        Parameters:\n            cls: The class object.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            tuple: A tuple containing the file paths of the downloaded encoder and decoder models.\n        \"\"\"\n        model_name = kwargs.get(\"sam_model\", \"sam_vit_b_01ec64\")\n        quant = kwargs.get(\"sam_quant\", False)\n\n        fname_encoder = f\"{model_name}.encoder.onnx\"\n        fname_decoder = f\"{model_name}.decoder.onnx\"\n\n        if quant:\n            fname_encoder = f\"{model_name}.encoder.quant.onnx\"\n            fname_decoder = f\"{model_name}.decoder.quant.onnx\"\n\n        pooch.retrieve(\n            f\"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}\",\n            None,\n            fname=fname_encoder,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        pooch.retrieve(\n            f\"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}\",\n            None,\n            fname=fname_decoder,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        if fname_encoder == \"sam_vit_h_4b8939.encoder.onnx\" and not os.path.exists(\n            os.path.join(\n                cls.u2net_home(*args, **kwargs), \"sam_vit_h_4b8939.encoder_data.bin\"\n            )\n        ):\n            content = bytearray()\n\n            for i in range(1, 4):\n                
pooch.retrieve(\n                    f\"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin\",\n                    None,\n                    fname=f\"sam_vit_h_4b8939.encoder_data.{i}.bin\",\n                    path=cls.u2net_home(*args, **kwargs),\n                    progressbar=True,\n                )\n\n                fbin = os.path.join(\n                    cls.u2net_home(*args, **kwargs),\n                    f\"sam_vit_h_4b8939.encoder_data.{i}.bin\",\n                )\n                content.extend(open(fbin, \"rb\").read())\n                os.remove(fbin)\n\n            with open(\n                os.path.join(\n                    cls.u2net_home(*args, **kwargs),\n                    \"sam_vit_h_4b8939.encoder_data.bin\",\n                ),\n                \"wb\",\n            ) as fp:\n                fp.write(content)\n\n        return (\n            os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),\n            os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),\n        )\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Return the name of the Sam session.\n\n        This method returns the string value 'sam'.\n\n        Parameters:\n            cls: The class object.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The string value 'sam'.\n        \"\"\"\n        return \"sam\"\n"
  },
  {
    "path": "rembg/sessions/silueta.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass SiluetaSession(BaseSession):\n    \"\"\"This is a class representing a SiluetaSession object.\"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predict the mask of the input image.\n\n        This method takes an image as input, preprocesses it, and performs a prediction to generate a mask. The generated mask is then post-processed and returned as a list of PILImage objects.\n\n        Parameters:\n            img (PILImage): The input image to be processed.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of post-processed masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Download the pre-trained model file.\n\n        This method downloads the pre-trained model file from a specified URL. 
The file is saved to the U2NET home directory.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name()}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:55e59e0d8062d2f5d013f4725ee84782\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Return the name of the model.\n\n        This method returns the name of the Silueta model.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"silueta\"\n"
  },
  {
    "path": "rembg/sessions/u2net.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass U2netSession(BaseSession):\n    \"\"\"\n    This class represents a U2net session, which is a subclass of BaseSession.\n    \"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the output masks for the input image using the inner session.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: The list of output masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred.clip(0, 1) * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the U2net model file from a specific URL and saves it.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model file.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else 
\"md5:60024c5c889badc19c04ad937298a77b\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the U2net session.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the session.\n        \"\"\"\n        return \"u2net\"\n"
  },
  {
    "path": "rembg/sessions/u2net_cloth_seg.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\npalette1 = [\n    0,\n    0,\n    0,\n    255,\n    255,\n    255,\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n]\n\npalette2 = [\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n    255,\n    255,\n    255,\n    0,\n    0,\n    0,\n]\n\npalette3 = [\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n    0,\n    255,\n    255,\n    255,\n]\n\n\nclass Unet2ClothSession(BaseSession):\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predict the cloth category of an image.\n\n        This method takes an image as input and predicts the cloth category of the image.\n        The method uses the inner_session to make predictions using a pre-trained model.\n        The predicted mask is then converted to an image and resized to match the size of the input image.\n        Depending on the cloth category specified in the method arguments, the method applies different color palettes to the mask and appends the resulting images to a list.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of images representing the predicted masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)\n            ),\n        )\n\n        pred = np.argmax(ort_outs[0], axis=1, keepdims=True)\n        pred = np.squeeze(pred, 0)\n        pred = np.squeeze(pred, 0)\n\n        mask = Image.fromarray(pred.astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        masks = []\n\n        
cloth_category = kwargs.get(\"cc\") or kwargs.get(\"cloth_category\")\n\n        def upper_cloth():\n            mask1 = mask.copy()\n            mask1.putpalette(palette1)\n            mask1 = mask1.convert(\"RGB\").convert(\"L\")\n            masks.append(mask1)\n\n        def lower_cloth():\n            mask2 = mask.copy()\n            mask2.putpalette(palette2)\n            mask2 = mask2.convert(\"RGB\").convert(\"L\")\n            masks.append(mask2)\n\n        def full_cloth():\n            mask3 = mask.copy()\n            mask3.putpalette(palette3)\n            mask3 = mask3.convert(\"RGB\").convert(\"L\")\n            masks.append(mask3)\n\n        if cloth_category == \"upper\":\n            upper_cloth()\n        elif cloth_category == \"lower\":\n            lower_cloth()\n        elif cloth_category == \"full\":\n            full_cloth()\n        else:\n            upper_cloth()\n            lower_cloth()\n            full_cloth()\n\n        return masks\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:2434d1f3cb744e0e49386c906e5a08bb\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        return \"u2net_cloth_seg\"\n"
  },
  {
    "path": "rembg/sessions/u2net_custom.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport onnxruntime as ort\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass U2netCustomSession(BaseSession):\n    \"\"\"This is a class representing a custom session for the U2net model.\"\"\"\n\n    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):\n        \"\"\"\n        Initialize a new U2netCustomSession object.\n\n        Parameters:\n            model_name (str): The name of the model.\n            sess_opts (ort.SessionOptions): The session options.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Raises:\n            ValueError: If model_path is None.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        super().__init__(model_name, sess_opts, *args, **kwargs)\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predict the segmentation mask for the input image.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of PILImage objects representing the segmentation mask.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, 
Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Resolve the path of the user-supplied model file; nothing is downloaded.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments; model_path is required.\n\n        Returns:\n            str: The absolute, user-expanded path to the model file.\n\n        Raises:\n            ValueError: If model_path is None.\n        \"\"\"\n        model_path = kwargs.get(\"model_path\")\n        if model_path is None:\n            raise ValueError(\"model_path is required\")\n\n        return os.path.abspath(os.path.expanduser(model_path))\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Get the name of the model.\n\n        Parameters:\n            *args: Additional positional arguments.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"u2net_custom\"\n"
  },
  {
    "path": "rembg/sessions/u2net_human_seg.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass U2netHumanSegSession(BaseSession):\n    \"\"\"\n    This class represents a session for performing human segmentation using the U2Net model.\n    \"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts human segmentation masks for the input image.\n\n        Parameters:\n            img (PILImage): The input image.\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            List[PILImage]: A list of predicted masks.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the U2Net model weights.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The path to the downloaded model weights.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:c09ddc2e0104f800e3e1bb4652583d1f\"\n   
         ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name of the U2Net model.\n\n        Parameters:\n            *args: Variable length argument list.\n            **kwargs: Arbitrary keyword arguments.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"u2net_human_seg\"\n"
  },
  {
    "path": "rembg/sessions/u2netp.py",
    "content": "import os\nfrom typing import List\n\nimport numpy as np\nimport pooch\nfrom PIL import Image\nfrom PIL.Image import Image as PILImage\n\nfrom .base import BaseSession\n\n\nclass U2netpSession(BaseSession):\n    \"\"\"This class represents a session for using the U2netp model.\"\"\"\n\n    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:\n        \"\"\"\n        Predicts the mask for the given image using the U2netp model.\n\n        Parameters:\n            img (PILImage): The input image.\n\n        Returns:\n            List[PILImage]: The predicted mask.\n        \"\"\"\n        ort_outs = self.inner_session.run(\n            None,\n            self.normalize(\n                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)\n            ),\n        )\n\n        pred = ort_outs[0][:, 0, :, :]\n\n        ma = np.max(pred)\n        mi = np.min(pred)\n\n        pred = (pred - mi) / (ma - mi)\n        pred = np.squeeze(pred)\n\n        mask = Image.fromarray((pred * 255).astype(\"uint8\"), mode=\"L\")\n        mask = mask.resize(img.size, Image.Resampling.LANCZOS)\n\n        return [mask]\n\n    @classmethod\n    def download_models(cls, *args, **kwargs):\n        \"\"\"\n        Downloads the U2netp model.\n\n        Returns:\n            str: The path to the downloaded model.\n        \"\"\"\n        fname = f\"{cls.name(*args, **kwargs)}.onnx\"\n        pooch.retrieve(\n            \"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx\",\n            (\n                None\n                if cls.checksum_disabled(*args, **kwargs)\n                else \"md5:8e83ca70e441ab06c318d82300c84806\"\n            ),\n            fname=fname,\n            path=cls.u2net_home(*args, **kwargs),\n            progressbar=True,\n        )\n\n        return os.path.join(cls.u2net_home(*args, **kwargs), fname)\n\n    @classmethod\n    def name(cls, *args, **kwargs):\n        \"\"\"\n        Returns the name 
of the U2netp model.\n\n        Returns:\n            str: The name of the model.\n        \"\"\"\n        return \"u2netp\"\n"
  },
  {
    "path": "rembg.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": [],\n      \"gpuType\": \"T4\"\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"accelerator\": \"GPU\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 3,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"hF9llNyHkiRB\",\n        \"outputId\": \"bd4e1cc0-f871-4c3f-d6e3-b503fe170f71\"\n      },\n      \"outputs\": [\n      ],\n      \"source\": [\n        \"! pip install \\\"rembg[gpu,cli]\\\"\\n\",\n        \"! git clone https://huggingface.co/spaces/KenjieDec/RemBG\\n\",\n        \"%cd RemBG\\n\",\n        \"!python app.py\"\n      ]\n    }\n  ]\n}\n"
  },
  {
    "path": "rembg.py",
    "content": "from rembg.cli import main\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "rembg.spec",
    "content": "# -*- mode: python ; coding: utf-8 -*-\nfrom PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs\n\ndatas = []\ndatas += collect_data_files('gradio_client')\ndatas += collect_data_files('gradio')\ndatas += collect_data_files('safehttpx')\ndatas += collect_data_files('groovy')\n\nbinaries = []\n\n# Collect onnxruntime (works for both CPU and GPU versions)\n# The pip packages are named differently (onnxruntime vs onnxruntime-gpu)\n# but both install the Python module as 'onnxruntime'\ntry:\n    datas += collect_data_files('onnxruntime')\n    binaries += collect_dynamic_libs('onnxruntime')\nexcept Exception:\n    pass\n\na = Analysis(\n    ['rembg.py'],\n    pathex=[],\n    binaries=binaries,\n    datas=datas,\n    hiddenimports=[\n        # Core dependencies\n        'numpy',\n        'PIL',\n        'scipy',\n        'scipy.ndimage',\n        'skimage',\n        'skimage.morphology',\n        'pymatting',\n        'pymatting.alpha',\n        'pymatting.foreground',\n        'pymatting.util',\n        'tqdm',\n        'pooch',\n        'jsonschema',\n        'onnxruntime',\n        # CLI dependencies\n        'click',\n        'uvicorn',\n        'fastapi',\n        'starlette',\n        'starlette.responses',\n        'aiohttp',\n        'asyncer',\n        'filetype',\n        'gradio',\n        'watchdog',\n        'sniffio',\n        'multipart',\n    ],\n    hookspath=[],\n    hooksconfig={},\n    runtime_hooks=[],\n    excludes=[],\n    noarchive=False,\n    module_collection_mode={\n        'gradio': 'py',\n    },\n)\npyz = PYZ(a.pure)\n\nexe = EXE(\n    pyz,\n    a.scripts,\n    [],\n    exclude_binaries=True,\n    name='rembg',\n    debug=False,\n    bootloader_ignore_signals=False,\n    strip=False,\n    upx=True,\n    console=True,\n    disable_windowed_traceback=False,\n    argv_emulation=False,\n    target_arch=None,\n    codesign_identity=None,\n    entitlements_file=None,\n)\ncoll = COLLECT(\n    exe,\n    
a.binaries,\n    a.datas,\n    strip=False,\n    upx=True,\n    upx_exclude=[],\n    name='rembg',\n)\n"
  },
  {
    "path": "tests/test_remove.py",
    "content": "from io import BytesIO\nfrom pathlib import Path\n\nfrom imagehash import phash as hash_img\nfrom PIL import Image\n\nfrom rembg import new_session, remove\n\nhere = Path(__file__).parent.resolve()\nfailures_dir = here / \"failures\"\nfailures_dir.mkdir(exist_ok=True)\n\ndef test_remove():\n    kwargs = {\n        \"sam\": {\n            \"anime-girl-1\" : {\n                \"sam_prompt\" :[{\"type\": \"point\", \"data\": [400, 165], \"label\": 1}],\n            }\n        }\n    }\n\n    for model in [\n        \"u2net\",\n        \"u2netp\",\n        \"u2net_human_seg\",\n        \"u2net_cloth_seg\",\n        \"silueta\",\n        \"isnet-general-use\",\n        \"isnet-anime\",\n        \"sam\",\n        \"birefnet-general\",\n        \"birefnet-general-lite\",\n        \"birefnet-portrait\",\n        \"birefnet-dis\",\n        \"birefnet-hrsod\",\n        \"birefnet-cod\",\n        \"birefnet-massive\"\n    ]:\n        for picture in [\"anime-girl-1\"]:\n            image_path = Path(here / \"fixtures\" / f\"{picture}.jpg\")\n            image = image_path.read_bytes()\n\n            actual = remove(image, session=new_session(model), **kwargs.get(model, {}).get(picture, {}))\n            actual_hash = hash_img(Image.open(BytesIO(actual)))\n\n            expected_path = Path(here / \"results\" / f\"{picture}.{model}.png\")\n            # Uncomment to update the expected results\n            # f = open(expected_path, \"wb\")\n            # f.write(actual)\n            # f.close()\n\n            expected = expected_path.read_bytes()\n            expected_hash = hash_img(Image.open(BytesIO(expected)))\n\n            print(f\"image_path: {image_path}\")\n            print(f\"expected_path: {expected_path}\")\n            print(f\"actual_hash: {actual_hash}\")\n            print(f\"expected_hash: {expected_hash}\")\n            print(f\"actual_hash == expected_hash: {actual_hash == expected_hash}\")\n            print(\"---\\n\")\n\n            if 
actual_hash != expected_hash:\n                # Save the images that failed so they can be compared\n                actual_failure_path = failures_dir / f\"{picture}.{model}.actual.png\"\n                expected_failure_path = failures_dir / f\"{picture}.{model}.expected.png\"\n\n                with open(actual_failure_path, \"wb\") as f:\n                    f.write(actual)\n                with open(expected_failure_path, \"wb\") as f:\n                    f.write(expected)\n\n                print(f\"FAILURE: Saved comparison images to {failures_dir}\")\n\n            assert actual_hash == expected_hash\n"
  }
]