[
  {
    "path": ".conda/meta.yaml",
    "content": "{% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %}\n{% set project = pyproject.get('project') %}\n{% set urls = pyproject.get('project', {}).get('urls') %}\n{% set version = environ.get('BUILD_VERSION', '0.2.0.dev0') %}\npackage:\n  name: {{ project.get('name') }}\n  version: {{ version }}\n\nsource:\n  fn: {{ project.get('name') }}-{{ version }}}.tar.gz\n  url: ../dist/{{ project.get('name') }}-{{ version }}.tar.gz\n\nbuild:\n  noarch: python\n  script: python setup.py install --single-version-externally-managed --record=record.txt\n\nrequirements:\n  host:\n    - python>=3.8, <4.0\n    - setuptools\n\n  run:\n    - pytorch >=2.0.0, <3.0.0\n\ntest:\n  # Python imports\n  imports:\n    - torchscan\n    - torchscan.modules\n    - torchscan.process\n    - torchscan.utils\n  requires:\n    - python\n\nabout:\n  home: {{ urls.get('repository') }}\n  license: Apache 2.0\n  license_file: {{ project.get('license', {}).get('file') }}\n  summary: {{ project.get('description') }}\n  # description: |\n  #   {{ data['long_description'] | replace(\"\\n\", \"\\n    \") | replace(\"#\", '\\#')}}\n  doc_url: {{ urls.get('documentation') }}\n  dev_url: {{ urls.get('repository') }}\n"
  },
  {
    "path": ".github/FUNDING.yml",
    "content": "# These are supported funding model platforms\n\ngithub: frgfm\npatreon: # Replace with a single Patreon username\nopen_collective: # Replace with an OpenCollective account\nko_fi: # Replace with a single Ko-fi username\ntidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel\ncommunity_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry\nliberapay: # Replace with a single Liberapay username\nissuehunt: # Replace with a single IssueHunt username\notechie: # Replace with a single Otechie username\ncustom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "content": "name: 🐛 Bug report\ndescription: Create a report to help us improve the library\nlabels: 'type: bug'\nassignees: frgfm\n\nbody:\n- type: markdown\n  attributes:\n    value: >\n      #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/frgfm/torch-cam/issues?q=is%3Aissue).\n- type: textarea\n  attributes:\n    label: Bug description\n    description: |\n      A clear and concise description of what the bug is.\n\n      Please explain the result you observed and the behavior you were expecting.\n    placeholder: |\n      A clear and concise description of what the bug is.\n  validations:\n    required: true\n\n- type: textarea\n  attributes:\n    label: Code snippet to reproduce the bug\n    description: |\n      Sample code to reproduce the problem.\n\n      Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability.\n    placeholder: |\n      ```python\n      Sample code to reproduce the problem\n      ```\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Error traceback\n    description: |\n      The error message you received running the code snippet, with the full traceback.\n\n      Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability.\n    placeholder: |\n      ```\n      The error message you got, with the full traceback.\n      ```\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Environment\n    description: |\n      Please run the following command and paste the output below.\n      ```sh\n      wget https://raw.githubusercontent.com/frgfm/torch-scan/main/.github/collect_env.py\n      # For security purposes, please check the contents of collect_env.py before running it.\n      python collect_env.py\n      ```\n  validations:\n    required: true\n- type: markdown\n  attributes:\n    value: >\n      Thanks for helping us improve 
the library!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: true\ncontact_links:\n  - name: Usage questions\n    url: https://github.com/frgfm/torch-scan/discussions\n    about: Ask questions and discuss with other TorchCAM community members\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "content": "name: 🚀 Feature request\ndescription: Submit a proposal/request for a new feature\nlabels: 'type: enhancement'\nassignees: frgfm\n\nbody:\n- type: textarea\n  attributes:\n    label: 🚀 Feature\n    description: >\n      A clear and concise description of the feature proposal\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Motivation & pitch\n    description: >\n      Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *\"I'm working on X and would like Y to be possible\"*. If this is related to another GitHub issue, please link here too.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Alternatives\n    description: >\n      A description of any alternative solutions or features you've considered, if any.\n- type: textarea\n  attributes:\n    label: Additional context\n    description: >\n      Add any other context or screenshots about the feature request.\n- type: markdown\n  attributes:\n    value: >\n      Thanks for contributing 🎉\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\n<!--\nWell, hello there! Thank you for proposing modifications to the project.\n\nMake sure to have both a short descriptive title & explain your modifications with the relevant context. Make sure to include reference to Github issues it is related to. For the sake of keeping the library light, if you modified existing dependencies or added new ones, please state it clearly in your description.\n\n-->\n\n<!-- Remove if not applicable -->\n\nCloses # (issue)\n\n\n## Before submitting\n- [ ] Was this discussed/approved in a Github [issue](https://github.com/frgfm/torch-scan/issues?q=is%3Aissue) or a [discussion](https://github.com/frgfm/torch-scan/discussions)? Please add a link to it if that's the case.\n- [ ] You have read the [contribution guidelines](https://github.com/frgfm/torch-scan/blob/main/CONTRIBUTING.md#submitting-a-pull-request) and followed them in this PR.\n- [ ] Did you make sure to update the documentation with your changes? Here are the\n      [documentation guidelines](https://github.com/frgm/torch-scan/tree/main/docs).\n- [ ] Did you write any new necessary tests?\n"
  },
  {
    "path": ".github/collect_env.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\n\"\"\"\nBased on https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py\nThis script outputs relevant system environment info\nRun it with `python collect_env.py`.\n\"\"\"\n\nfrom __future__ import absolute_import, division, print_function, unicode_literals\n\nimport locale\nimport os\nimport re\nimport subprocess  # noqa S404\nimport sys\nfrom pathlib import Path\nfrom typing import NamedTuple\n\ntry:\n    import torchscan\n\n    TORCHSCAN_AVAILABLE = True\nexcept (ImportError, NameError, AttributeError, OSError):\n    TORCHSCAN_AVAILABLE = False\n\ntry:\n    import torch\n\n    TORCH_AVAILABLE = True\nexcept (ImportError, NameError, AttributeError, OSError):\n    TORCH_AVAILABLE = False\n\nPY3 = sys.version_info >= (3, 0)\n\n\n# System Environment Information\nclass SystemEnv(NamedTuple):\n    torchscan_version: str\n    torch_version: str\n    os: str\n    python_version: str\n    is_cuda_available: bool\n    cuda_runtime_version: str\n    nvidia_driver_version: str\n    nvidia_gpu_models: str\n    cudnn_version: str\n\n\ndef run(command):\n    \"\"\"Returns (return-code, stdout, stderr)\"\"\"\n    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)\n    output, err = p.communicate()\n    rc = p.returncode\n    if PY3:\n        enc = locale.getpreferredencoding()\n        output = output.decode(enc)\n        err = err.decode(enc)\n    return rc, output.strip(), err.strip()\n\n\ndef run_and_read_all(run_lambda, command):\n    \"\"\"Runs command using run_lambda; reads and returns entire output if rc is 0\"\"\"\n    rc, out, _ = run_lambda(command)\n    if rc != 0:\n        return None\n    return out\n\n\ndef run_and_parse_first_match(run_lambda, command, regex):\n    \"\"\"Runs 
command using run_lambda, returns the first regex match if it exists\"\"\"\n    rc, out, _ = run_lambda(command)\n    if rc != 0:\n        return None\n    match = re.search(regex, out)\n    if match is None:\n        return None\n    return match.group(1)\n\n\ndef get_nvidia_driver_version(run_lambda):\n    if get_platform() == \"darwin\":\n        cmd = \"kextstat | grep -i cuda\"\n        return run_and_parse_first_match(run_lambda, cmd, r\"com[.]nvidia[.]CUDA [(](.*?)[)]\")\n    smi = get_nvidia_smi()\n    return run_and_parse_first_match(run_lambda, smi, r\"Driver Version: (.*?) \")\n\n\ndef get_gpu_info(run_lambda):\n    if get_platform() == \"darwin\":\n        if TORCH_AVAILABLE and torch.cuda.is_available():\n            return torch.cuda.get_device_name(None)\n        return None\n    smi = get_nvidia_smi()\n    uuid_regex = re.compile(r\" \\(UUID: .+?\\)\")\n    rc, out, _ = run_lambda(smi + \" -L\")\n    if rc != 0:\n        return None\n    # Anonymize GPUs by removing their UUID\n    return re.sub(uuid_regex, \"\", out)\n\n\ndef get_running_cuda_version(run_lambda):\n    return run_and_parse_first_match(run_lambda, \"nvcc --version\", r\"release .+ V(.*)\")\n\n\ndef get_cudnn_version(run_lambda):\n    \"\"\"This will return a list of libcudnn.so; it's hard to tell which one is being used\"\"\"\n    if get_platform() == \"win32\":\n        cudnn_cmd = 'where /R \"%CUDA_PATH%\\\\bin\" cudnn*.dll'\n    elif get_platform() == \"darwin\":\n        # CUDA libraries and drivers can be found in /usr/local/cuda/. 
See\n        # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install\n        # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac\n        # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.\n        cudnn_cmd = \"ls /usr/local/cuda/lib/libcudnn*\"\n    else:\n        cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d\" \" -f1 | rev'\n    rc, out, _ = run_lambda(cudnn_cmd)\n    # find will return 1 if there are permission errors or if not found\n    if len(out) == 0 or rc not in (1, 0):\n        lib = os.environ.get(\"CUDNN_LIBRARY\")\n        if lib is not None and Path(lib).is_file():\n            return os.path.realpath(lib)\n        return None\n    files = set()\n    for fn in out.split(\"\\n\"):\n        fn = os.path.realpath(fn)  # eliminate symbolic links\n        if Path(fn).is_file():\n            files.add(fn)\n    if not files:\n        return None\n    # Alphabetize the result because the order is non-deterministic otherwise\n    files = sorted(files)\n    if len(files) == 1:\n        return files[0]\n    result = \"\\n\".join(files)\n    return \"Probably one of the following:\\n{}\".format(result)\n\n\ndef get_nvidia_smi():\n    # Note: nvidia-smi is currently available only on Windows and Linux\n    smi = \"nvidia-smi\"\n    if get_platform() == \"win32\":\n        system_root = os.environ.get(\"SYSTEMROOT\", \"C:\\\\Windows\")\n        program_files_root = os.environ.get(\"PROGRAMFILES\", \"C:\\\\Program Files\")\n        legacy_path = Path(program_files_root) / \"NVIDIA Corporation\" / \"NVSMI\" / smi\n        new_path = Path(system_root) / \"System32\" / smi\n        smis = [new_path, legacy_path]\n        for candidate_smi in smis:\n            if Path(candidate_smi).exists():\n                smi = '\"{}\"'.format(candidate_smi)\n                break\n    return smi\n\n\ndef get_platform():\n    if sys.platform.startswith(\"linux\"):\n        return \"linux\"\n    if 
sys.platform.startswith(\"win32\"):\n        return \"win32\"\n    if sys.platform.startswith(\"cygwin\"):\n        return \"cygwin\"\n    if sys.platform.startswith(\"darwin\"):\n        return \"darwin\"\n    return sys.platform\n\n\ndef get_mac_version(run_lambda):\n    return run_and_parse_first_match(run_lambda, \"sw_vers -productVersion\", r\"(.*)\")\n\n\ndef get_windows_version(run_lambda):\n    return run_and_read_all(run_lambda, \"wmic os get Caption | findstr /v Caption\")\n\n\ndef get_lsb_version(run_lambda):\n    return run_and_parse_first_match(run_lambda, \"lsb_release -a\", r\"Description:\\t(.*)\")\n\n\ndef check_release_file(run_lambda):\n    return run_and_parse_first_match(run_lambda, \"cat /etc/*-release\", r'PRETTY_NAME=\"(.*)\"')\n\n\ndef get_os(run_lambda):\n    platform = get_platform()\n\n    if platform in (\"win32\", \"cygwin\"):\n        return get_windows_version(run_lambda)\n\n    if platform == \"darwin\":\n        version = get_mac_version(run_lambda)\n        if version is None:\n            return None\n        return \"Mac OSX {}\".format(version)\n\n    if platform == \"linux\":\n        # Ubuntu/Debian based\n        desc = get_lsb_version(run_lambda)\n        if desc is not None:\n            return desc\n\n        # Try reading /etc/*-release\n        desc = check_release_file(run_lambda)\n        if desc is not None:\n            return desc\n\n        return platform\n\n    # Unknown platform\n    return platform\n\n\ndef get_env_info():\n    run_lambda = run\n\n    torchscan_str = torchscan.__version__ if TORCHSCAN_AVAILABLE else \"N/A\"\n\n    if TORCH_AVAILABLE:\n        torch_str = torch.__version__\n        cuda_available_str = torch.cuda.is_available()\n    else:\n        torch_str = cuda_available_str = \"N/A\"\n\n    return SystemEnv(\n        torchscan_version=torchscan_str,\n        torch_version=torch_str,\n        python_version=\".\".join(map(str, sys.version_info[:3])),\n        
is_cuda_available=cuda_available_str,\n        cuda_runtime_version=get_running_cuda_version(run_lambda),\n        nvidia_gpu_models=get_gpu_info(run_lambda),\n        nvidia_driver_version=get_nvidia_driver_version(run_lambda),\n        cudnn_version=get_cudnn_version(run_lambda),\n        os=get_os(run_lambda),\n    )\n\n\nenv_info_fmt = \"\"\"\nTorchScan version: {torchscan_version}\nPyTorch version: {torch_version}\n\nOS: {os}\n\nPython version: {python_version}\nIs CUDA available: {is_cuda_available}\nCUDA runtime version: {cuda_runtime_version}\nGPU models and configuration: {nvidia_gpu_models}\nNvidia driver version: {nvidia_driver_version}\ncuDNN version: {cudnn_version}\n\"\"\".strip()\n\n\ndef pretty_str(envinfo):\n    def replace_nones(dct, replacement=\"Could not collect\"):\n        for key in dct:\n            if dct[key] is not None:\n                continue\n            dct[key] = replacement\n        return dct\n\n    def replace_bools(dct, true=\"Yes\", false=\"No\"):\n        for key in dct:\n            if dct[key] is True:\n                dct[key] = true\n            elif dct[key] is False:\n                dct[key] = false\n        return dct\n\n    def maybe_start_on_next_line(string):\n        # If `string` is multiline, prepend a \\n to it.\n        if string is not None and len(string.split(\"\\n\")) > 1:\n            return \"\\n{}\\n\".format(string)\n        return string\n\n    mutable_dict = envinfo._asdict()\n\n    # If nvidia_gpu_models is multiline, start on the next line\n    mutable_dict[\"nvidia_gpu_models\"] = maybe_start_on_next_line(envinfo.nvidia_gpu_models)\n\n    # If the machine doesn't have CUDA, report some fields as 'No CUDA'\n    dynamic_cuda_fields = [\n        \"cuda_runtime_version\",\n        \"nvidia_gpu_models\",\n        \"nvidia_driver_version\",\n    ]\n    all_cuda_fields = [*dynamic_cuda_fields, \"cudnn_version\"]\n    all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None for field in 
dynamic_cuda_fields)\n    if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:\n        for field in all_cuda_fields:\n            mutable_dict[field] = \"No CUDA\"\n\n    # Replace True with Yes, False with No\n    mutable_dict = replace_bools(mutable_dict)\n\n    # Replace all None objects with 'Could not collect'\n    mutable_dict = replace_nones(mutable_dict)\n\n    return env_info_fmt.format(**mutable_dict)\n\n\ndef get_pretty_env_info():\n    \"\"\"Collects environment information for debugging purposes\n\n    Returns:\n        str: environment information\n    \"\"\"\n    return pretty_str(get_env_info())\n\n\ndef main():\n    print(\"Collecting environment information...\")\n    output = get_pretty_env_info()\n    print(output)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# To get started with Dependabot version updates, you'll need to specify which\n# package ecosystems to update and where the package manifests are located.\n# Please see the documentation for all configuration options:\n# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file\n\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"monthly\"\n      time: \"06:00\"\n      timezone: \"Europe/Paris\"\n    groups:\n      gh-actions:\n        patterns:\n          - \"*\"\n    reviewers:\n      - \"frgfm\"\n    assignees:\n      - \"frgfm\"\n  - package-ecosystem: \"pip\"\n    directory: \"/\"\n    schedule:\n      interval: \"daily\"\n      time: \"06:00\"\n      timezone: \"Europe/Paris\"\n    reviewers:\n      - \"frgfm\"\n    assignees:\n      - \"frgfm\"\n    allow:\n      - dependency-name: \"ruff\"\n      - dependency-name: \"mypy\"\n      - dependency-name: \"pre-commit\"\n"
  },
  {
    "path": ".github/labeler.yml",
    "content": "'module: crawler':\n- changed-files:\n  - any-glob-to-any-file: torchscan/crawler.py\n\n'module: modules':\n- changed-files:\n  - any-glob-to-any-file: torchscan/modules/*\n\n'module: process':\n- changed-files:\n  - any-glob-to-any-file: torchscan/process/*\n\n'module: utils':\n- changed-files:\n  - any-glob-to-any-file: torchscan/utils.py\n\n'ext: docs':\n- changed-files:\n  - any-glob-to-any-file: docs/*\n\n'ext: scripts':\n- changed-files:\n  - any-glob-to-any-file: scripts/*\n\n'ext: tests':\n- changed-files:\n  - any-glob-to-any-file: tests/*\n\n'topic: ci':\n- changed-files:\n  - any-glob-to-any-file: .github/*\n\n'topic: docs':\n- changed-files:\n  - any-glob-to-any-file:\n    - README.md\n    - CONTRIBUTING.md\n    - CODE_OF_CONDUCT.md\n    - CITATION.cff\n    - LICENSE\n\n'topic: build':\n- changed-files:\n  - any-glob-to-any-file:\n    - setup.py\n    - pyproject.toml\n\n'topic: style':\n- changed-files:\n  - any-glob-to-any-file: .pre-commit-config.yaml\n"
  },
  {
    "path": ".github/release.yml",
    "content": "changelog:\n  exclude:\n    labels:\n      - ignore-for-release\n  categories:\n    - title: Breaking Changes 🛠\n      labels:\n        - \"type: breaking change\"\n    # NEW FEATURES\n    - title: New Features 🚀\n      labels:\n        - \"type: feat\"\n    # BUG FIXES\n    - title: Bug Fixes 🐛\n      labels:\n        - \"type: fix\"\n    # IMPROVEMENTS\n    - title: Improvements\n      labels:\n        - \"type: improvement\"\n    # MISC\n    - title: Miscellaneous\n      labels:\n        - \"type: misc\"\n"
  },
  {
    "path": ".github/verify_labels.py",
    "content": "# Copyright (C) 2022-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\n\"\"\"\nBorrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py\nThis script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in\n'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled,\nthis script is a no-op.\nNote: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to torchvision\nwith no labeling responsibility, so we don't want to bother them.\n\"\"\"\n\nfrom typing import Any, Set, Tuple\n\nimport requests\n\n# For a PR to be properly labeled it should have one primary label and one secondary label\n\n# Should specify the type of change\nPRIMARY_LABELS = {\n    \"type: new feature\",\n    \"type: bug\",\n    \"type: enhancement\",\n    \"type: misc\",\n}\n\n# Should specify what has been modified\nSECONDARY_LABELS = {\n    \"topic: documentation\",\n    \"module: modules\",\n    \"module: process\",\n    \"module: crawler\",\n    \"module: utils\",\n    \"ext: docs\",\n    \"ext: scripts\",\n    \"ext: tests\",\n    \"topic: build\",\n    \"topic: ci\",\n}\n\nGH_ORG = \"frgfm\"\nGH_REPO = \"torch-scan\"\n\n\ndef query_repo(cmd: str, *, accept) -> Any:\n    response = requests.get(\n        f\"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}\", headers={\"Accept\": accept}, timeout=5\n    )\n    return response.json()\n\n\ndef get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]:\n    # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request\n    data = query_repo(f\"pulls/{pr_number}\", accept=\"application/vnd.github.v3+json\")\n    merger = data.get(\"merged_by\", {}).get(\"login\")\n    labels = {label[\"name\"] for label in 
data[\"labels\"]}\n    return merger, labels\n\n\ndef main(args):\n    merger, labels = get_pr_merger_and_labels(args.pr)\n    is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels))\n    if isinstance(merger, str) and not is_properly_labeled:\n        print(f\"@{merger}\")\n\n\ndef parse_args():\n    import argparse\n\n    parser = argparse.ArgumentParser(\n        description=\"PR label checker\", formatter_class=argparse.ArgumentDefaultsHelpFormatter\n    )\n\n    parser.add_argument(\"pr\", type=int, help=\"PR number\")\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": ".github/workflows/builds.yml",
    "content": "name: builds\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  build:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n        python: [3.8, 3.9, '3.10', 3.11, 3.12]\n        exclude:\n          - os: macos-latest\n            python: 3.8\n          - os: macos-latest\n            python: 3.9\n          - os: macos-latest\n            python: '3.10'\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Install package\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e .\n      - name: Import package\n        run: python -c \"import torchscan; print(torchscan.__version__)\"\n\n  pypi:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.11\n          architecture: x64\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system setuptools wheel twine --upgrade\n      - run: |\n          python setup.py sdist bdist_wheel\n          twine check dist/*\n\n  conda:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: conda-incubator/setup-miniconda@v3\n        with:\n          auto-update-conda: true\n          python-version: \"3.11\"\n      - name: Install dependencies\n        shell: bash -el {0}\n        run: conda install -y conda-build conda-verify\n      - name: Build conda\n        shell: bash -el {0}\n        run: |\n          python setup.py sdist\n          mkdir conda-dist\n          conda env list\n          conda build .conda/ -c pytorch --output-folder conda-dist\n          ls -l conda-dist/noarch/*tar.bz2\n"
  },
  {
    "path": ".github/workflows/doc-status.yml",
    "content": "name: GH-Pages Status\non:\n  page_build\n\njobs:\n  see-page-build-payload:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.11\n          architecture: x64\n      - name: check status\n        run: |\n          import os\n          status, errormsg = os.getenv('STATUS'), os.getenv('ERROR')\n          if status != 'built': raise AssertionError(f\"There was an error building the page on GitHub pages.\\n\\nStatus: {status}\\n\\nError messsage: {errormsg}\")\n        shell: python\n        env:\n          STATUS: ${{ github.event.build.status }}\n          ERROR: ${{ github.event.build.error.message }}\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: docs\non:\n  push:\n    branches: main\n\njobs:\n  docs-deploy:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.9]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e \".[docs]\"\n\n      - name: Build documentation\n        run: cd docs && bash build.sh\n\n      - name: Documentation sanity check\n        run: test -e docs/build/index.html || exit\n\n      - name: Install SSH Client 🔑\n        uses: webfactory/ssh-agent@v0.9.0\n        with:\n          ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }}\n\n      - name: Deploy to Github Pages\n        uses: JamesIves/github-pages-deploy-action@v4\n        with:\n          BRANCH: gh-pages\n          FOLDER: 'docs/build'\n          COMMIT_MESSAGE: '[skip ci] Documentation updates'\n          CLEAN: true\n          SSH: true\n"
  },
  {
    "path": ".github/workflows/pr-labels.yml",
    "content": "name: pr-labels\n\non:\n  pull_request:\n    branches: main\n    types: closed\n\njobs:\n  is-properly-labeled:\n    if: github.event.pull_request.merged == true\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n      - name: Install requests\n        run: pip install requests\n      - name: Process commit and find merger responsible for labeling\n        id: commit\n        run: echo \"::set-output name=merger::$(python .github/verify_labels.py ${{ github.event.pull_request.number }})\"\n      - name: Comment PR\n        uses: actions/github-script@7.0.1\n        if: ${{ steps.commit.outputs.merger != '' }}\n        with:\n          github-token: ${{ secrets.GITHUB_TOKEN }}\n          script: |\n            const { issue: { number: issue_number }, repo: { owner, repo }  } = context;\n            github.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\\nYou merged this PR, but it is not correctly labeled. The list of valid labels is available at https://github.com/frgfm/torch-cam/blob/main/.github/verify_labels.py' });\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: publish\n\non:\n  release:\n    types: [published]\n\njobs:\n  pypi:\n    if: \"!github.event.release.prerelease\"\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.11\n          architecture: x64\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system setuptools wheel twine --upgrade\n      - name: Build and publish\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: |\n          echo \"BUILD_VERSION=${GITHUB_REF#refs/*/}\" | cut -c 2- >> $GITHUB_ENV\n          python setup.py sdist bdist_wheel\n          twine check dist/*\n          twine upload dist/*\n\n  pypi-check:\n    if: \"!github.event.release.prerelease\"\n    runs-on: ubuntu-latest\n    needs: pypi\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.11\n          architecture: x64\n      - name: Install package\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system torchscan\n          python -c \"import torchscan; print(torchscan.__version__)\"\n\n  conda:\n    if: \"!github.event.release.prerelease\"\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - name: Miniconda setup\n        uses: conda-incubator/setup-miniconda@v3\n        with:\n          auto-update-conda: true\n          python-version: 3.11\n      - name: Install dependencies\n        shell: bash -el {0}\n        run: conda install -y conda-build conda-verify anaconda-client\n      - name: Build and publish\n        shell: bash -el {0}\n        env:\n          ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }}\n        run: |\n          echo \"BUILD_VERSION=${GITHUB_REF#refs/*/}\" | cut -c 2- >> 
$GITHUB_ENV\n          python setup.py sdist\n          mkdir conda-dist\n          conda build .conda/ -c pytorch --output-folder conda-dist\n          ls -l conda-dist/noarch/*tar.bz2\n          anaconda upload conda-dist/noarch/*tar.bz2\n\n  conda-check:\n    if: \"!github.event.release.prerelease\"\n    runs-on: ubuntu-latest\n    needs: conda\n    steps:\n      - name: Miniconda setup\n        uses: conda-incubator/setup-miniconda@v3\n        with:\n          auto-update-conda: true\n          python-version: 3.11\n          auto-activate-base: true\n      - name: Install package\n        shell: bash -el {0}\n        run: |\n          conda install -c frgfm torchscan\n          python -c \"import torchscan; print(torchscan.__version__)\"\n"
  },
  {
    "path": ".github/workflows/pull_requests.yml",
    "content": "name: pull_requests\n\non:\n  pull_request:\n    branches: main\n\njobs:\n  docs:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.9\n          architecture: x64\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e \".[docs]\"\n\n      - name: Build documentation\n        run: cd docs && bash build.sh\n\n      - name: Documentation sanity check\n        run: test -e docs/build/index.html || exit\n\n  triage:\n    permissions:\n      contents: read\n      pull-requests: write\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/labeler@v5\n      with:\n        repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n"
  },
  {
    "path": ".github/workflows/style.yml",
    "content": "name: style\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  ruff:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.11]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Run ruff\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e '.[quality]'\n          ruff --version\n          ruff check --diff .\n\n  mypy:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.11]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Run mypy\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e '.[quality]'\n          mypy --version\n          mypy\n\n  ruff-format:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.11]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Run ruff\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e '.[quality]'\n          ruff --version\n          ruff format --check --diff .\n\n  precommit-hooks:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.11]\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Run pre-commit hooks\n        run: |\n          python -m pip install 
--upgrade uv\n          uv pip install --system -e '.[quality]'\n          git checkout -b temp\n          pre-commit install\n          pre-commit --version\n          pre-commit run --all-files\n"
  },
  {
    "path": ".github/workflows/tests.yml",
    "content": "name: tests\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  pytest:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n        python: [3.11]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n      - uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python }}\n          architecture: x64\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade uv\n          uv pip install --system -e \".[test]\" --upgrade\n      - name: Run unittests\n        run: pytest --cov=torchscan --cov-report xml tests/\n      - uses: actions/upload-artifact@v4\n        with:\n          name: coverage-reports\n          path: ./coverage.xml\n\n  codecov-upload:\n    runs-on: ubuntu-latest\n    needs: pytest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/download-artifact@v4\n      - name: Upload coverage to Codecov\n        uses: codecov/codecov-action@v5\n        with:\n          token: ${{ secrets.CODECOV_TOKEN }}\n          flags: unittests\n          directory: ./coverage-reports\n          fail_ci_if_error: true\n\n  headers:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest]\n    steps:\n      - uses: actions/checkout@v4\n        with:\n          persist-credentials: false\n      - name: Check the headers\n        uses: frgfm/validate-python-headers@main\n        with:\n          license: 'Apache-2.0'\n          owner: 'François-Guillaume Fernandez'\n          starting-year: 2020\n          folders: 'torchscan,scripts,docs,.github'\n          ignores: 'version.py,__init__.py'\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\nconda-dist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# Package version\ntorchscan/version.py\n# Conda distribution\nconda-dist/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "default_language_version:\n    python: python3.11\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n      - id: check-added-large-files\n      - id: check-ast\n      - id: check-case-conflict\n      - id: check-json\n      - id: check-merge-conflict\n      - id: check-symlinks\n      - id: check-toml\n      - id: check-xml\n      - id: check-yaml\n        exclude: .conda\n      - id: debug-statements\n        language_version: python3\n      - id: end-of-file-fixer\n      - id: no-commit-to-branch\n        args: ['--branch', 'main']\n      - id: requirements-txt-fixer\n      - id: trailing-whitespace\n  - repo: https://github.com/charliermarsh/ruff-pre-commit\n    rev: 'v0.6.4'\n    hooks:\n      - id: ruff\n        args:\n          - --fix\n      - id: ruff-format\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, 
and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\nfg-feedback@protonmail.com.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. 
Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.0, available at\nhttps://www.contributor-covenant.org/version/2/0/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to torchscan\n\nEverything you need to know to contribute efficiently to the project.\n\nWhatever the way you wish to contribute to the project, please respect the [code of conduct](CODE_OF_CONDUCT.md).\n\n\n## Codebase structure\n\n- [torchscan](https://github.com/frgfm/torch-scan/blob/main/torchscan) - The actual torchscan library\n- [tests](https://github.com/frgfm/torch-scan/blob/main/tests) - Python unit tests\n- [docs](https://github.com/frgfm/torch-scan/blob/main/docs) - Sphinx documentation building\n- [scripts](https://github.com/frgfm/torch-scan/blob/main/scripts) - Example and utilities scripts\n\n\n\n## Continuous Integration\n\nThis project uses the following integrations to ensure proper codebase maintenance:\n\n- [Github Workflow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - run jobs for package build and coverage\n- [Codacy](https://www.codacy.com/) - analyzes commits for code quality\n- [Codecov](https://codecov.io/) - reports back coverage results\n\nAs a contributor, you will only have to ensure coverage of your code by adding appropriate unit testing of your code.\n\n\n## Feedback\n\n### Feature requests & bug report\n\nWhether you encountered a problem, or you have a feature suggestion, your input has value and can be used by contributors to reference it in their developments. For this purpose, we advise you to use Github [issues](https://github.com/frgfm/torch-scan/issues).\n\nFirst, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in.\n\n### Questions\n\nIf you are wondering how to do something with TorchScan, or a more general question, you should consider checking out Github [discussions](https://github.com/frgfm/torch-scan/discussions). 
See it as a Q&A forum, or the TorchScan-specific StackOverflow!\n\n\n\n## Submitting a Pull Request\n\n### Preparing your local branch\n\n1 - Fork this [repository](https://github.com/frgfm/torch-scan) by clicking on the \"Fork\" button at the top right of the page. This will create a copy of the project under your GitHub account (cf. [Fork a repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo)).\n\n2 - [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) to your local disk and set the upstream to this repo\n```shell\ngit clone git@github.com:<YOUR_GITHUB_ACCOUNT>/torch-scan.git\ncd torch-scan\ngit remote add upstream https://github.com/frgfm/torch-scan.git\n```\n\n3 - You should not work on the `main` branch, so let's create a new one\n```shell\ngit checkout -b a-short-description\n```\n\n4 - You only have to set your development environment now. First uninstall any existing installation of the library with `pip uninstall torchscan`, then:\n```shell\npip install -e \".[dev]\"\npre-commit install\n```\n\n### Developing your feature\n\n#### Commits\n\n- **Code**: ensure to provide docstrings to your Python code. 
In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later.\n- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/)\n\n#### Unit tests\n\nIn order to run the same unit tests as the CI workflows, you can run unittests locally:\n\n```shell\nmake test\n```\n\n#### Code quality\n\nThe CI will also run some sanity checks (header format, dependency consistency, etc.), which you can run as follows:\n\n```shell\nmake quality\n```\n\nThis will read `pyproject.toml` and run:\n- lint checking, formatting ([ruff](https://docs.astral.sh/ruff/))\n- type annotation checking ([mypy](https://github.com/python/mypy))\n\nYou can apply automatic fix to most of those by running:\n\n```shell\nmake style\n```\n\n### Submit your modifications\n\nPush your last modifications to your remote branch\n```shell\ngit push -u origin a-short-description\n```\n\nThen [open a Pull Request](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) from your fork's branch. Follow the instructions of the Pull Request template and then click on \"Create a pull request\".\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": "# this target runs checks on all files\nquality:\n\truff format --check .\n\truff check .\n\tmypy\n\n# this target runs checks on all files and potentially modifies some of them\nstyle:\n\truff format .\n\truff check --fix .\n\n# Run tests for the library\ntest:\n\tpytest --cov=torchscan tests/\n\n# Build documentation for current version\nsingle-docs:\n\tsphinx-build docs/source docs/_build -a\n\n# Check that docs can build\nfull-docs:\n\tcd docs && bash build.sh\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img src=\"https://github.com/frgfm/torch-scan/releases/download/v0.1.1/logo_text.png\" width=\"30%\">\n</p>\n\n<p align=\"center\">\n  <a href=\"https://github.com/frgfm/torch-scan/actions/workflows/builds.yml\">\n    <img alt=\"CI Status\" src=\"https://img.shields.io/github/actions/workflow/status/frgfm/torch-scan/builds.yml?branch=main&label=CI&logo=github&style=flat-square\">\n  </a>\n  <a href=\"https://github.com/astral-sh/ruff\">\n    <img src=\"https://img.shields.io/badge/Linter-Ruff-FCC21B?style=flat-square&logo=ruff&logoColor=white\" alt=\"ruff\">\n  </a>\n  <a href=\"https://github.com/astral-sh/ruff\">\n    <img src=\"https://img.shields.io/badge/Formatter-Ruff-FCC21B?style=flat-square&logo=Python&logoColor=white\" alt=\"ruff\">\n  </a>\n  <a href=\"https://www.codacy.com/gh/frgfm/torch-scan/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=frgfm/torch-scan&amp;utm_campaign=Badge_Grade\"><img src=\"https://app.codacy.com/project/badge/Grade/9dc68e8bfce34d9dbc8b44a350e9adc7\"/></a>\n  <a href=\"https://codecov.io/gh/frgfm/torch-scan\">\n    <img src=\"https://img.shields.io/codecov/c/github/frgfm/torch-scan.svg?logo=codecov&style=flat-square&label=Coverage\" alt=\"Test coverage percentage\">\n  </a>\n</p>\n<p align=\"center\">\n  <a href=\"https://pypi.org/project/torchscan/\">\n    <img src=\"https://img.shields.io/pypi/v/torchscan.svg?logo=PyPI&logoColor=fff&style=flat-square&label=PyPI\" alt=\"PyPi Version\">\n  </a>\n  <a href=\"https://anaconda.org/frgfm/torchscan\">\n    <img src=\"https://img.shields.io/conda/v/frgfm/torchscan.svg?logo=anaconda&label=Conda&logoColor=fff&style=flat-square\" alt=\"Conda Version\">\n  </a>\n  <img src=\"https://img.shields.io/pypi/pyversions/torchscan.svg?logo=Python&label=Python&logoColor=fff&style=flat-square\" alt=\"pyversions\">\n  <a href=\"https://github.com/frgfm/torch-scan/blob/main/LICENSE\">\n    <img 
src=\"https://img.shields.io/github/license/frgfm/torch-scan.svg?label=License&logoColor=fff&style=flat-square\" alt=\"License\">\n  </a>\n</p>\n<p align=\"center\">\n  <a href=\"https://frgfm.github.io/torch-scan\">\n    <img src=\"https://img.shields.io/github/actions/workflow/status/frgfm/torch-scan/docs.yml?branch=main&label=Documentation&logo=read-the-docs&logoColor=white&style=flat-square\" alt=\"Documentation Status\">\n  </a>\n</p>\n\n\nThe very useful [summary](https://www.tensorflow.org/api_docs/python/tf/keras/Model#summary) method of `tf.keras.Model` but for PyTorch, with more useful information.\n\n\n## Quick Tour\n\n### Inspecting your PyTorch architecture\n\nSimilarly to the `torchsummary` implementation, `torchscan` brings useful module information into readable format. For nested complex architectures, you can use a maximum depth of display as follows:\n\n```python\nfrom torchvision.models import densenet121\nfrom torchscan import summary\n\nmodel = densenet121().eval().cuda()\nsummary(model, (3, 224, 224), max_depth=2)\n```\n\nwhich would yield\n\n```shell\n__________________________________________________________________________________________\nLayer                        Type                  Output Shape              Param #\n==========================================================================================\ndensenet                     DenseNet              (-1, 1000)                0\n├─features                   Sequential            (-1, 1024, 7, 7)          0\n|    └─conv0                 Conv2d                (-1, 64, 112, 112)        9,408\n|    └─norm0                 BatchNorm2d           (-1, 64, 112, 112)        257\n|    └─relu0                 ReLU                  (-1, 64, 112, 112)        0\n|    └─pool0                 MaxPool2d             (-1, 64, 56, 56)          0\n|    └─denseblock1           _DenseBlock           (-1, 256, 56, 56)         338,316\n|    └─transition1           _Transition           (-1, 128, 28, 
28)         33,793\n|    └─denseblock2           _DenseBlock           (-1, 512, 28, 28)         930,072\n|    └─transition2           _Transition           (-1, 256, 14, 14)         133,121\n|    └─denseblock3           _DenseBlock           (-1, 1024, 14, 14)        2,873,904\n|    └─transition3           _Transition           (-1, 512, 7, 7)           528,385\n|    └─denseblock4           _DenseBlock           (-1, 1024, 7, 7)          2,186,272\n|    └─norm5                 BatchNorm2d           (-1, 1024, 7, 7)          4,097\n├─classifier                 Linear                (-1, 1000)                1,025,000\n==========================================================================================\nTrainable params: 7,978,856\nNon-trainable params: 0\nTotal params: 7,978,856\n------------------------------------------------------------------------------------------\nModel size (params + buffers): 30.76 Mb\nFramework & CUDA overhead: 423.57 Mb\nTotal RAM usage: 454.32 Mb\n------------------------------------------------------------------------------------------\nFloating Point Operations on forward: 5.74 GFLOPs\nMultiply-Accumulations on forward: 2.87 GMACs\nDirect memory accesses on forward: 2.90 GDMAs\n__________________________________________________________________________________________\n```\n\nResults are aggregated to the selected depth for improved readability.\n\nFor reference, here are explanations of a few acronyms:\n\n- **FLOPs**: floating-point operations (not to be confused with FLOPS which is FLOPs per second)\n- **MACs**: multiply-accumulate operations (cf. [wikipedia](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation))\n- **DMAs**: direct memory accesses (many argue that it is more relevant than FLOPs or MACs to compare model inference speeds cf. 
[wikipedia](https://en.wikipedia.org/wiki/Direct_memory_access))\n\n\n\nAdditionally, for highway nets (models without multiple branches / skip connections), `torchscan` supports receptive field estimation.\n\n```python\nfrom torchvision.models import vgg16\nfrom torchscan import summary\n\nmodel = vgg16().eval().cuda()\nsummary(model, (3, 224, 224), receptive_field=True, max_depth=0)\n```\n\nwhich will add the layer's receptive field (relatively to the last convolutional layer) to the summary.\n\n\n## Setup\n\nPython 3.8 (or newer) and [pip](https://pip.pypa.io/en/stable/)/[conda](https://docs.conda.io/en/latest/miniconda.html) are required to install Torchscan.\n\n### Stable release\n\nYou can install the last stable release of the package using [pypi](https://pypi.org/project/torch-scan/) as follows:\n\n```shell\npip install torchscan\n```\n\nor using [conda](https://anaconda.org/frgfm/torchscan):\n\n```shell\nconda install -c frgfm torchscan\n```\n\n### Developer installation\n\nAlternatively, if you wish to use the latest features of the project that haven't made their way to a release yet, you can install the package from source:\n\n```shell\ngit clone https://github.com/frgfm/torch-scan.git\npip install -e torch-scan/.\n```\n\n\n## Benchmark\n\nBelow are the results for classification models supported by `torchvision` for a single image with 3 color channels of size `224x224` (apart from  `inception_v3`   which uses `299x299`).\n\n| Model              | Params (M) | FLOPs (G) | MACs (G) | DMAs (G) | RF   |\n| ------------------ | ---------- | --------- | -------- | -------- | ---- |\n| alexnet            | 61.1       | 1.43      | 0.71     | 0.72     | 195  |\n| googlenet          | 6.62       | 3.01      | 1.51     | 1.53     | --   |\n| vgg11              | 132.86     | 15.23     | 7.61     | 7.64     | 150  |\n| vgg11_bn           | 132.87     | 15.26     | 7.63     | 7.66     | 150  |\n| vgg13              | 133.05     | 22.63     | 11.31    | 11.35    | 
156  |\n| vgg13_bn           | 133.05     | 22.68     | 11.33    | 11.37    | 156  |\n| vgg16              | 138.36     | 30.96     | 15.47    | 15.52    | 212  |\n| vgg16_bn           | 138.37     | 31.01     | 15.5     | 15.55    | 212  |\n| vgg19              | 143.67     | 39.28     | 19.63    | 19.69    | 268  |\n| vgg19_bn           | 143.68     | 39.34     | 19.66    | 19.72    | 268  |\n| resnet18           | 11.69      | 3.64      | 1.82     | 1.84     | --   |\n| resnet34           | 21.8       | 7.34      | 3.67     | 3.7      | --   |\n| resnet50           | 25.56      | 8.21      | 4.11     | 4.15     | --   |\n| resnet101          | 44.55      | 15.66     | 7.83     | 7.9      | --   |\n| resnet152          | 60.19      | 23.1      | 11.56    | 11.65    | --   |\n| inception_v3       | 27.16      | 11.45     | 5.73     | 5.76     | --   |\n| squeezenet1_0      | 1.25       | 1.64      | 0.82     | 0.83     | --   |\n| squeezenet1_1      | 1.24       | 0.7       | 0.35     | 0.36     | --   |\n| wide_resnet50_2    | 68.88      | 22.84     | 11.43    | 11.51    | --   |\n| wide_resnet101_2   | 126.89     | 45.58     | 22.8     | 22.95    | --   |\n| densenet121        | 7.98       | 5.74      | 2.87     | 2.9      | --   |\n| densenet161        | 28.68      | 15.59     | 7.79     | 7.86     | --   |\n| densenet169        | 14.15      | 6.81      | 3.4      | 3.44     | --   |\n| densenet201        | 20.01      | 8.7       | 4.34     | 4.39     | --   |\n| resnext50_32x4d    | 25.03      | 8.51      | 4.26     | 4.3      | --   |\n| resnext101_32x8d   | 88.79      | 32.93     | 16.48    | 16.61    | --   |\n| mobilenet_v2       | 3.5        | 0.63      | 0.31     | 0.32     | --   |\n| shufflenet_v2_x0_5 | 1.37       | 0.09      | 0.04     | 0.05     | --   |\n| shufflenet_v2_x1_0 | 2.28       | 0.3       | 0.15     | 0.15     | --   |\n| shufflenet_v2_x1_5 | 3.5        | 0.6       | 0.3      | 0.31     | --   |\n| shufflenet_v2_x2_0 | 7.39       | 1.18  
    | 0.59     | 0.6      | --   |\n| mnasnet0_5         | 2.22       | 0.22      | 0.11     | 0.12     | --   |\n| mnasnet0_75        | 3.17       | 0.45      | 0.23     | 0.24     | --   |\n| mnasnet1_0         | 4.38       | 0.65      | 0.33     | 0.34     | --   |\n| mnasnet1_3         | 6.28       | 1.08      | 0.54     | 0.56     | --   |\n\nThe above results were produced using the `scripts/benchmark.py` script.\n\n*Note: receptive field computation is currently only valid for highway nets.*\n\n\n\n## What else\n\n### Documentation\n\nThe full package documentation is available [here](https://frgfm.github.io/torch-scan/) for detailed specifications.\n\n\n### Example script\n\nAn example script is provided for you to benchmark torchvision models using the library:\n\n```shell\npython scripts/benchmark.py\n```\n\n\n## Credits\n\nThis project is developed and maintained by the repo owner, but the implementation was inspired or helped by the following contributions:\n\n- [Pytorch summary](https://github.com/sksq96/pytorch-summary): existing PyTorch porting of `tf.keras.Model.summary`\n- [Torchstat](https://github.com/Swall0w/torchstat): another module inspection tool\n- [Flops counter Pytorch](https://github.com/sovrasov/flops-counter.pytorch): operation counter tool\n- [THOP](https://github.com/Lyken17/pytorch-OpCounter): PyTorch Op counter\n- Number of operations and memory estimation articles by [Matthijs Hollemans](https://machinethink.net/blog/how-fast-is-my-model/), and [Sicara](https://www.sicara.ai/blog/2019-28-10-deep-learning-memory-usage-and-pytorch-optimization-tricks)\n- [Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440)\n\n\n## Citation\n\nIf you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference:\n\n```bibtex\n@misc{torchscan2020,\n    title={Torchscan: meaningful module insights},\n    author={François-Guillaume Fernandez},\n    year={2020},\n    
month={March},\n    publisher = {GitHub},\n    howpublished = {\\url{https://github.com/frgfm/torch-scan}}\n}\n```\n\n\n## Contributing\n\nAny sort of contribution is greatly appreciated!\n\nYou can find a short guide in [`CONTRIBUTING`](CONTRIBUTING.md) to help grow this project!\n\n\n\n## License\n\nDistributed under the Apache 2.0 License. See [`LICENSE`](LICENSE) for more information.\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/README.md",
    "content": "# Changing the documentation\n\nThe documentation of this project is built using `sphinx`. In order to install all the build dependencies, run the following command from the root folder of the repository:\n```shell\npip install -e \".[docs]\"\n```\n\n---\n**NOTE**\n\nYou are only generating the documentation to inspect it locally. Only the source files are pushed to the remote repository, the documentation will be built automatically by the CI.\n\n---\n\n## Build the documentation\n\n### Latest version\n\nIn most cases, you will only be changing the documentation of the latest version (dev version). In this case, you can build the documentation (the HTML files) with the following command:\n\n```shell\nsphinx-build docs/source docs/_build -a\n```\n\nThen open `docs/_build/index.html` in your web browser to navigate in it.\n\n\n### Multi-version documentation\n\nIn rare cases, you might want to modify the documentation for other versions. You will then have to build the documentation for the multiple versions of the package, which you can do by running this command from the `docs` folder:\n```shell\nbash build.sh\n```\n"
  },
  {
    "path": "docs/build.sh",
    "content": "function deploy_doc(){\n    if [ ! -z \"$1\" ]\n    then\n        git checkout $1\n    fi\n    COMMIT=$(git rev-parse --short HEAD)\n    echo \"Creating doc at commit\" $COMMIT \"and pushing to folder $2\"\n    pip install -U ..\n    if [ ! -z \"$2\" ]\n    then\n        if [ \"$2\" == \"latest\" ]; then\n            echo \"Pushing main\"\n            sphinx-build source build/$2 -a\n        elif [ -d build/$2 ]; then\n            echo \"Directory\" $2 \"already exists\"\n        else\n            echo \"Pushing version\" $2\n            cp -r _static source/ && cp _conf.py source/conf.py\n            sphinx-build source build/$2 -a\n        fi\n    else\n        echo \"Pushing stable\"\n        cp -r _static source/ && cp _conf.py source/conf.py\n        sphinx-build source build -a\n    fi\n    git checkout source/ && git clean -f source/\n}\n\n# exit when any command fails\nset -e\n# You can find the commit for each tag on https://github.com/frgfm/torch-scan/tags\nif [ -d build ]; then rm -Rf build; fi\nmkdir build\ncp -r source/_static .\ncp source/conf.py _conf.py\ngit fetch --all --tags --unshallow\ndeploy_doc \"\" latest\ndeploy_doc \"7ac9c839\" v0.1.0\ndeploy_doc \"900eb166\" v0.1.1\ndeploy_doc \"29fa4ed1\" # v0.1.2 Latest stable release\nrm -rf _build _static\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=source\r\nset BUILDDIR=build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.http://sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/source/_static/css/custom.css",
    "content": "h1 {\n    font-size: 200%;\n}\n\n/* Github button */\n\n.github-repo {\n    display: flex;\n    justify-content: center;\n}\n\n/* Version control */\n\n.version-button {\n    color: gray;\n    border: none;\n    padding: 5px;\n    font-size: 15px;\n    cursor: pointer;\n}\n\n.version-button:hover, .version-button:focus {\n    color: white;\n    background-color: gray;\n}\n\n.version-dropdown {\n    display: none;\n    min-width: 160px;\n    overflow: auto;\n    font-size: 15px;\n}\n\n.version-dropdown a {\n    color: gray;\n    padding: 3px 4px;\n    text-decoration: none;\n    display: block;\n}\n\n.version-dropdown a:hover {\n    color: white;\n    background-color: gray;\n}\n\n.version-show {\n    display: block;\n}\n"
  },
  {
    "path": "docs/source/_static/js/custom.js",
    "content": "// Based on https://github.com/huggingface/transformers/blob/master/docs/source/_static/js/custom.js\n\n\n// These two things need to be updated at each release for the version selector.\n// Last stable version\nconst stableVersion = \"v0.1.2\"\n// Dictionary doc folder to label. The last stable version should have an empty key.\nconst versionMapping = {\n    \"latest\": \"latest\",\n    \"\": \"v0.1.2 (stable)\",\n    \"v0.1.1\": \"v0.1.1\",\n    \"v0.1.0\": \"v0.1.0\",\n}\n\nfunction addGithubButton() {\n    const div = `\n        <div class=\"github-repo\">\n            <a\n                class=\"github-button\"\n                href=\"https://github.com/frgfm/torch-scan\"\n                data-size=\"large\"\n                data-show-count=\"true\"\n                aria-label=\"Star frgfm/torch-scan on GitHub\">Star</a>\n        </div>\n    `;\n    document.querySelector(\".sidebar-brand\").insertAdjacentHTML('afterend', div);\n}\n\nfunction addVersionControl() {\n    // To grab the version currently in view, we parse the url\n    const parts = location.toString().split('/');\n    let versionIndex = parts.length - 2;\n    // Index page may not have a last part with filename.html so we need to go up\n    if (parts[parts.length - 1] != \"\" && ! parts[parts.length - 1].match(/\\.html$|^search.html?/)) {\n        versionIndex = parts.length - 1;\n    }\n    const version = parts[versionIndex];\n\n    // Menu with all the links,\n    const versionMenu = document.createElement(\"div\");\n\n    const htmlLines = [];\n    for (const [key, value] of Object.entries(versionMapping)) {\n        let baseUrlIndex = (version == \"torch-scan\") ? 
versionIndex + 1: versionIndex;\n        var urlParts = parts.slice(0, baseUrlIndex);\n        if (key != \"\") {\n            urlParts = urlParts.concat([key]);\n        }\n        urlParts = urlParts.concat(parts.slice(versionIndex+1));\n        htmlLines.push(`<a href=\"${urlParts.join('/')}\">${value}</a>`);\n    }\n\n    versionMenu.classList.add(\"version-dropdown\");\n    versionMenu.innerHTML = htmlLines.join('\\n');\n\n    // Button for version selection\n    const versionButton = document.createElement(\"div\");\n    versionButton.classList.add(\"version-button\");\n    let label = (version == \"torch-scan\") ? stableVersion : version\n    versionButton.innerText = label.concat(\" ▼\");\n\n    // Toggle the menu when we click on the button\n    versionButton.addEventListener(\"click\", () => {\n        versionMenu.classList.toggle(\"version-show\");\n    });\n\n    // Hide the menu when we click elsewhere\n    window.addEventListener(\"click\", (event) => {\n        if (event.target != versionButton){\n            versionMenu.classList.remove('version-show');\n        }\n    });\n\n    // Container\n    const div = document.createElement(\"div\");\n    div.appendChild(versionButton);\n    div.appendChild(versionMenu);\n    div.style.paddingTop = '5px';\n    div.style.paddingBottom = '5px';\n    div.style.display = 'block';\n    div.style.textAlign = 'center';\n\n    const scrollDiv = document.querySelector(\".sidebar-brand\");\n    scrollDiv.insertBefore(div, scrollDiv.children[1]);\n}\n\n/*!\n * github-buttons v2.2.10\n * (c) 2019 なつき\n * @license BSD-2-Clause\n */\n/**\n * modified to run programmatically\n */\nfunction parseGithubButtons (){\"use strict\";var e=window.document,t=e.location,o=window.encodeURIComponent,r=window.decodeURIComponent,n=window.Math,a=window.HTMLElement,i=window.XMLHttpRequest,l=\"https://unpkg.com/github-buttons@2.2.10/dist/buttons.html\",c=i&&i.prototype&&\"withCredentials\"in 
i.prototype,d=c&&a&&a.prototype.attachShadow&&!a.prototype.attachShadow.prototype,s=function(e,t,o){e.addEventListener?e.addEventListener(t,o):e.attachEvent(\"on\"+t,o)},u=function(e,t,o){e.removeEventListener?e.removeEventListener(t,o):e.detachEvent(\"on\"+t,o)},h=function(e,t,o){var r=function(n){return u(e,t,r),o(n)};s(e,t,r)},f=function(e,t,o){var r=function(n){if(t.test(e.readyState))return u(e,\"readystatechange\",r),o(n)};s(e,\"readystatechange\",r)},p=function(e){return function(t,o,r){var n=e.createElement(t);if(o)for(var a in o){var i=o[a];null!=i&&(null!=n[a]?n[a]=i:n.setAttribute(a,i))}if(r)for(var l=0,c=r.length;l<c;l++){var d=r[l];n.appendChild(\"string\"==typeof d?e.createTextNode(d):d)}return n}},g=p(e),b=function(e){var t;return function(){t||(t=1,e.apply(this,arguments))}},m=\"body{margin:0}a{color:#24292e;text-decoration:none;outline:0}.octicon{display:inline-block;vertical-align:text-top;fill:currentColor}.widget{ display:inline-block;overflow:hidden;font-family:-apple-system, BlinkMacSystemFont, \\\"Segoe UI\\\", Helvetica, Arial, sans-serif;font-size:0;white-space:nowrap;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.btn,.social-count{display:inline-block;height:14px;padding:2px 5px;font-size:11px;font-weight:600;line-height:14px;vertical-align:bottom;cursor:pointer;border:1px solid #c5c9cc;border-radius:0.25em}.btn{background-color:#eff3f6;background-image:-webkit-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:-moz-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:linear-gradient(180deg, #fafbfc, #eff3f6 90%);background-position:-1px -1px;background-repeat:repeat-x;background-size:110% 110%;border-color:rgba(27,31,35,0.2);-ms-filter:\\\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')\\\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', 
endColorstr='#FFEEF2F5')}.btn:active{background-color:#e9ecef;background-image:none;border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);box-shadow:inset 0 0.15em 0.3em rgba(27,31,35,0.15)}.btn:focus,.btn:hover{background-color:#e6ebf1;background-image:-webkit-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:-moz-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:linear-gradient(180deg, #f0f3f6, #e6ebf1 90%);border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);-ms-filter:\\\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')\\\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')}.social-count{position:relative;margin-left:5px;background-color:#fff}.social-count:focus,.social-count:hover{color:#0366d6}.social-count b,.social-count i{position:absolute;top:50%;left:0;display:block;width:0;height:0;margin:-4px 0 0 -4px;border:solid transparent;border-width:4px 4px 4px 0;_line-height:0;_border-top-color:red !important;_border-bottom-color:red !important;_border-left-color:red !important;_filter:chroma(color=red)}.social-count b{border-right-color:#c5c9cc}.social-count i{margin-left:-3px;border-right-color:#fff}.lg .btn,.lg .social-count{height:16px;padding:5px 10px;font-size:12px;line-height:16px}.lg .social-count{margin-left:6px}.lg .social-count b,.lg .social-count i{margin:-5px 0 0 -5px;border-width:5px 5px 5px 0}.lg .social-count i{margin-left:-4px}\\n\",v={\"mark-github\":{width:16,height:16,path:'<path fill-rule=\"evenodd\" d=\"M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 
1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z\"/>'},eye:{width:16,height:16,path:'<path fill-rule=\"evenodd\" d=\"M8.06 2C3 2 0 8 0 8s3 6 8.06 6C13 14 16 8 16 8s-3-6-7.94-6zM8 12c-2.2 0-4-1.78-4-4 0-2.2 1.8-4 4-4 2.22 0 4 1.8 4 4 0 2.22-1.78 4-4 4zm2-4c0 1.11-.89 2-2 2-1.11 0-2-.89-2-2 0-1.11.89-2 2-2 1.11 0 2 .89 2 2z\"/>'},star:{width:14,height:16,path:'<path fill-rule=\"evenodd\" d=\"M14 6l-4.9-.64L7 1 4.9 5.36 0 6l3.6 3.26L2.67 14 7 11.67 11.33 14l-.93-4.74L14 6z\"/>'},\"repo-forked\":{width:10,height:16,path:'<path fill-rule=\"evenodd\" d=\"M8 1a1.993 1.993 0 0 0-1 3.72V6L5 8 3 6V4.72A1.993 1.993 0 0 0 2 1a1.993 1.993 0 0 0-1 3.72V6.5l3 3v1.78A1.993 1.993 0 0 0 5 15a1.993 1.993 0 0 0 1-3.72V9.5l3-3V4.72A1.993 1.993 0 0 0 8 1zM2 4.2C1.34 4.2.8 3.65.8 3c0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3 10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3-10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2z\"/>'},\"issue-opened\":{width:14,height:16,path:'<path fill-rule=\"evenodd\" d=\"M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z\"/>'},\"cloud-download\":{width:16,height:16,path:'<path fill-rule=\"evenodd\" d=\"M9 12h2l-3 3-3-3h2V7h2v5zm3-8c0-.44-.91-3-4.5-3C5.08 1 3 2.92 3 5 1.02 5 0 6.52 0 8c0 1.53 1 3 3 3h3V9.7H3C1.38 9.7 1.3 8.28 1.3 8c0-.17.05-1.7 1.7-1.7h1.3V5c0-1.39 1.56-2.7 3.2-2.7 2.55 0 3.13 1.55 3.2 1.8v1.2H12c.81 0 2.7.22 2.7 2.2 0 2.09-2.25 2.2-2.7 2.2h-2V11h2c2.08 0 4-1.16 4-3.5C16 5.06 14.08 4 12 4z\"/>'}},w={},x=function(e,t,o){var r=p(e.ownerDocument),n=e.appendChild(r(\"style\",{type:\"text/css\"}));n.styleSheet?n.styleSheet.cssText=m:n.appendChild(e.ownerDocument.createTextNode(m));var 
a,l,d=r(\"a\",{className:\"btn\",href:t.href,target:\"_blank\",innerHTML:(a=t[\"data-icon\"],l=/^large$/i.test(t[\"data-size\"])?16:14,a=(\"\"+a).toLowerCase().replace(/^octicon-/,\"\"),{}.hasOwnProperty.call(v,a)||(a=\"mark-github\"),'<svg version=\"1.1\" width=\"'+l*v[a].width/v[a].height+'\" height=\"'+l+'\" viewBox=\"0 0 '+v[a].width+\" \"+v[a].height+'\" class=\"octicon octicon-'+a+'\" aria-hidden=\"true\">'+v[a].path+\"</svg>\"),\"aria-label\":t[\"aria-label\"]||void 0},[\" \",r(\"span\",{},[t[\"data-text\"]||\"\"])]);/\\.github\\.com$/.test(\".\"+d.hostname)?/^https?:\\/\\/((gist\\.)?github\\.com\\/[^\\/?#]+\\/[^\\/?#]+\\/archive\\/|github\\.com\\/[^\\/?#]+\\/[^\\/?#]+\\/releases\\/download\\/|codeload\\.github\\.com\\/)/.test(d.href)&&(d.target=\"_top\"):(d.href=\"#\",d.target=\"_self\");var u,h,g,x,y=e.appendChild(r(\"div\",{className:\"widget\"+(/^large$/i.test(t[\"data-size\"])?\" lg\":\"\")},[d]));/^(true|1)$/i.test(t[\"data-show-count\"])&&\"github.com\"===d.hostname&&(u=d.pathname.replace(/^(?!\\/)/,\"/\").match(/^\\/([^\\/?#]+)(?:\\/([^\\/?#]+)(?:\\/(?:(subscription)|(fork)|(issues)|([^\\/?#]+)))?)?(?:[\\/?#]|$)/))&&!u[6]?(u[2]?(h=\"/repos/\"+u[1]+\"/\"+u[2],u[3]?(x=\"subscribers_count\",g=\"watchers\"):u[4]?(x=\"forks_count\",g=\"network\"):u[5]?(x=\"open_issues_count\",g=\"issues\"):(x=\"stargazers_count\",g=\"stargazers\")):(h=\"/users/\"+u[1],g=x=\"followers\"),function(e,t){var o=w[e]||(w[e]=[]);if(!(o.push(t)>1)){var r=b(function(){for(delete w[e];t=o.shift();)t.apply(null,arguments)});if(c){var n=new i;s(n,\"abort\",r),s(n,\"error\",r),s(n,\"load\",function(){var e;try{e=JSON.parse(n.responseText)}catch(e){return void r(e)}r(200!==n.status,e)}),n.open(\"GET\",e),n.send()}else{var a=this||window;a._=function(e){a._=null,r(200!==e.meta.status,e.data)};var 
l=p(a.document)(\"script\",{async:!0,src:e+(/\\?/.test(e)?\"&\":\"?\")+\"callback=_\"}),d=function(){a._&&a._({meta:{}})};s(l,\"load\",d),s(l,\"error\",d),l.readyState&&f(l,/de|m/,d),a.document.getElementsByTagName(\"head\")[0].appendChild(l)}}}.call(this,\"https://api.github.com\"+h,function(e,t){if(!e){var n=t[x];y.appendChild(r(\"a\",{className:\"social-count\",href:t.html_url+\"/\"+g,target:\"_blank\",\"aria-label\":n+\" \"+x.replace(/_count$/,\"\").replace(\"_\",\" \").slice(0,n<2?-1:void 0)+\" on GitHub\"},[r(\"b\"),r(\"i\"),r(\"span\",{},[(\"\"+n).replace(/\\B(?=(\\d{3})+(?!\\d))/g,\",\")])]))}o&&o(y)})):o&&o(y)},y=window.devicePixelRatio||1,C=function(e){return(y>1?n.ceil(n.round(e*y)/y*2)/2:n.ceil(e))||0},F=function(e,t){e.style.width=t[0]+\"px\",e.style.height=t[1]+\"px\"},k=function(t,r){if(null!=t&&null!=r)if(t.getAttribute&&(t=function(e){for(var t={href:e.href,title:e.title,\"aria-label\":e.getAttribute(\"aria-label\")},o=[\"icon\",\"text\",\"size\",\"show-count\"],r=0,n=o.length;r<n;r++){var a=\"data-\"+o[r];t[a]=e.getAttribute(a)}return null==t[\"data-text\"]&&(t[\"data-text\"]=e.textContent||e.innerText),t}(t)),d){var a=g(\"span\",{title:t.title||void 0});x(a.attachShadow({mode:\"closed\"}),t,function(){r(a)})}else{var i=g(\"iframe\",{src:\"javascript:0\",title:t.title||void 0,allowtransparency:!0,scrolling:\"no\",frameBorder:0});F(i,[0,0]),i.style.border=\"none\";var c=function(){var a,d=i.contentWindow;try{a=d.document.body}catch(t){return void e.body.appendChild(i.parentNode.removeChild(i))}u(i,\"load\",c),x.call(d,a,t,function(e){var a=function(e){var t=e.offsetWidth,o=e.offsetHeight;if(e.getBoundingClientRect){var r=e.getBoundingClientRect();t=n.max(t,C(r.width)),o=n.max(o,C(r.height))}return[t,o]}(e);i.parentNode.removeChild(i),h(i,\"load\",function(){F(i,a)}),i.src=l+\"#\"+(i.name=function(e){var t=[];for(var r in e){var n=e[r];null!=n&&t.push(o(r)+\"=\"+o(n))}return 
t.join(\"&\")}(t)),r(i)})};s(i,\"load\",c),e.body.appendChild(i)}};t.protocol+\"//\"+t.host+t.pathname===l?x(e.body,function(e){for(var t={},o=e.split(\"&\"),n=0,a=o.length;n<a;n++){var i=o[n];if(\"\"!==i){var l=i.split(\"=\");t[r(l[0])]=null!=l[1]?r(l.slice(1).join(\"=\")):void 0}}return t}(window.name||t.hash.replace(/^#/,\"\"))):function(t){if(/m/.test(e.readyState)||!/g/.test(e.readyState)&&!e.documentElement.doScroll)setTimeout(t);else if(e.addEventListener){var o=b(t);h(e,\"DOMContentLoaded\",o),h(window,\"load\",o)}else f(e,/m/,t)}(function(){for(var t=e.querySelectorAll?e.querySelectorAll(\"a.github-button\"):function(){for(var t=[],o=e.getElementsByTagName(\"a\"),r=0,n=o.length;r<n;r++)~(\" \"+o[r].className+\" \").replace(/[ \\t\\n\\f\\r]+/g,\" \").indexOf(\" github-button \")&&t.push(o[r]);return t}(),o=0,r=t.length;o<r;o++)!function(e){k(e,function(t){e.parentNode.replaceChild(t,e)})}(t[o])})};\n\nfunction onLoad() {\n    addVersionControl();\n    addGithubButton();\n    parseGithubButtons();\n}\n\nwindow.addEventListener(\"load\", onLoad);\n"
  },
  {
    "path": "docs/source/changelog.rst",
    "content": "Changelog\n=========\n\n\nv0.1.2 (2022-08-03)\n-------------------\nRelease note: `v0.1.2 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.2>`_\n\nv0.1.1 (2020-08-04)\n-------------------\nRelease note: `v0.1.1 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.1>`_\n\nv0.1.0 (2020-05-21)\n-------------------\nRelease note: `v0.1.0 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.0>`_\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\n# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport sys\nfrom datetime import datetime\nfrom pathlib import Path\n\nsys.path.insert(0, Path().cwd().parent.parent)\nimport torchscan\n\n# -- Project information -----------------------------------------------------\n\nmaster_doc = \"index\"\nproject = \"torchscan\"\ncopyright = f\"2020-{datetime.now().year}, François-Guillaume Fernandez\"\nauthor = \"François-Guillaume Fernandez\"\n\n# The full version, including alpha/beta/rc tags\nversion = torchscan.__version__\nrelease = torchscan.__version__ + \"-git\"\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.mathjax\",\n    \"sphinxemoji.sphinxemoji\",  # cf. 
https://sphinxemojicodes.readthedocs.io/en/stable/\n    \"sphinx_copybutton\",\n]\n\nnapoleon_use_ivar = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"friendly\"\npygments_dark_style = \"monokai\"\nhighlight_language = \"python3\"\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"furo\"\n\nhtml_title = \"Torchscan\"\nhtml_logo = \"_static/images/logo.png\"\nhtml_favicon = \"_static/images/favicon.ico\"\nlanguage = \"en\"\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  
For a list of options available for each theme, see the\n# documentation.\n#\nhtml_theme_options = {\n    \"footer_icons\": [\n        {\n            \"name\": \"GitHub\",\n            \"url\": \"https://github.com/frgfm/torch-scan\",\n            \"html\": \"\"\"\n                <svg stroke=\"currentColor\" fill=\"currentColor\" stroke-width=\"0\" viewBox=\"0 0 16 16\">\n                    <path fill-rule=\"evenodd\" d=\"M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z\"></path>\n                </svg>\n            \"\"\",\n            \"class\": \"\",\n        },\n    ],\n    \"source_repository\": \"https://github.com/frgfm/torch-scan/\",\n    \"source_branch\": \"main\",\n    \"source_directory\": \"docs/source/\",\n    \"sidebar_hide_name\": True,\n}\n\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\n\n# Add googleanalytics id\n# ref: https://github.com/orenhecht/googleanalytics/blob/master/sphinxcontrib/googleanalytics.py\ndef add_ga_javascript(app, pagename, templatename, context, doctree):\n    metatags = context.get(\"metatags\", \"\")\n    metatags += \"\"\"\n    <!-- Global site tag (gtag.js) - Google Analytics -->\n<script async src=\"https://www.googletagmanager.com/gtag/js?id={0}\"></script>\n<script>\n  window.dataLayer = window.dataLayer || [];\n  function gtag(){{dataLayer.push(arguments);}}\n  gtag('js', new Date());\n  gtag('config', '{0}');\n</script>\n    \"\"\".format(app.config.googleanalytics_id)\n    context[\"metatags\"] = metatags\n\n\ndef setup(app):\n    app.add_config_value(\"googleanalytics_id\", \"UA-148140560-3\", \"html\")\n    app.add_css_file(\"css/custom.css\")\n    app.add_js_file(\"js/custom.js\")\n    app.connect(\"html-page-context\", add_ga_javascript)\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": "**************************************\nTorchScan: inspect your PyTorch models\n**************************************\n\nThe :mod:`torchscan` package provides tools for analyzing your PyTorch modules and models. In addition to performance benchmarks, a comprehensive architecture comparison requires some insights into the model complexity and its usage of computational and memory resources.\n\n\nThis project is meant for:\n\n* |:zap:| **exploration**: easily assess the influence of your architecture on resource consumption\n* |:woman_scientist:| **research**: quickly implement your own ideas to mitigate latency\n\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Getting Started\n   :hidden:\n\n   installing\n\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Package Reference\n   :hidden:\n\n   torchscan\n   modules\n   process\n   utils\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Notes\n   :hidden:\n\n   changelog\n\n\nSupported layers\n^^^^^^^^^^^^^^^^\n\nHere is the list of supported layers for FLOPs, MACs, DMAs and receptive field computation:\n\nNon-linear activations\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n* `torch.nn.ReLU <https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_\n* `torch.nn.ELU <https://pytorch.org/docs/stable/generated/torch.nn.ELU.html>`_\n* `torch.nn.LeakyReLU <https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html>`_\n* `torch.nn.ReLU6 <https://pytorch.org/docs/stable/generated/torch.nn.ReLU6.html>`_\n* `torch.nn.Tanh <https://pytorch.org/docs/stable/generated/torch.nn.Tanh.html>`_\n* `torch.nn.Sigmoid <https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html>`_\n\nLinear layers\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n* `torch.nn.Identity <https://pytorch.org/docs/stable/generated/torch.nn.Identity.html>`_\n* `torch.nn.Linear <https://pytorch.org/docs/stable/generated/torch.nn.Linear.html>`_\n\nConvolutions\n\"\"\"\"\"\"\"\"\"\"\"\"\n\n* `torch.nn.Conv1d 
<https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html>`_\n* `torch.nn.Conv2d <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html>`_\n* `torch.nn.Conv3d <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html>`_\n* `torch.nn.ConvTranspose1d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html>`_\n* `torch.nn.ConvTranspose2d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html>`_\n* `torch.nn.ConvTranspose3d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html>`_\n\nPooling\n\"\"\"\"\"\"\"\n\n* `torch.nn.MaxPool1d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html>`_\n* `torch.nn.MaxPool2d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html>`_\n* `torch.nn.MaxPool3d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html>`_\n* `torch.nn.AvgPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html>`_\n* `torch.nn.AvgPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html>`_\n* `torch.nn.AvgPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool3d.html>`_\n* `torch.nn.AdaptiveMaxPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool1d.html>`_\n* `torch.nn.AdaptiveMaxPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool2d.html>`_\n* `torch.nn.AdaptiveMaxPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool3d.html>`_\n* `torch.nn.AdaptiveAvgPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool1d.html>`_\n* `torch.nn.AdaptiveAvgPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html>`_\n* `torch.nn.AdaptiveAvgPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool3d.html>`_\n\nNormalization\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n* `torch.nn.BatchNorm1d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`_\n* `torch.nn.BatchNorm2d 
<https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_\n* `torch.nn.BatchNorm3d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html>`_\n\nOther\n\"\"\"\"\"\n\n* `torch.nn.Flatten <https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html>`_\n* `torch.nn.Dropout <https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_\n\n\n*Please note that the functional API of PyTorch is not supported.*\n"
  },
  {
    "path": "docs/source/installing.rst",
    "content": "\n************\nInstallation\n************\n\nThis library requires `Python <https://www.python.org/downloads/>`_ 3.8 or higher.\n\nVia Python Package\n==================\n\nInstall the last stable release of the package using `pip <https://pip.pypa.io/en/stable/installation/>`_:\n\n.. code:: bash\n\n    pip install torchscan\n\n\nVia Conda\n=========\n\nInstall the last stable release of the package using `conda <https://docs.conda.io/en/latest/>`_:\n\n.. code:: bash\n\n    conda install -c frgfm torchscan\n\n\nVia Git\n=======\n\nInstall the library in developer mode:\n\n.. code:: bash\n\n    git clone https://github.com/frgfm/torch-scan.git\n    pip install -e torch-scan/.\n"
  },
  {
    "path": "docs/source/modules.rst",
    "content": "torchscan.modules\n=================\n\nThe modules subpackage contains tools for inspection of modules.\n\n.. currentmodule:: torchscan.modules\n\n\nFLOPs\n-----\nRelated to the number of floating point operations performed during model inference.\n\n.. autofunction:: module_flops\n\n\nMACs\n-----\nRelated to the number of multiply-accumulate operations performed during model inference\n\n.. autofunction:: module_macs\n\n\nDMAs\n----\nRelated to the number of direct memory accesses during model inference\n\n.. autofunction:: module_dmas\n\n\nReceptive field\n---------------\nRelated to the effective receptive field of a layer\n\n.. autofunction:: module_rf\n"
  },
  {
    "path": "docs/source/process.rst",
    "content": "torchscan.process\n=================\n\nThe process subpackage contains tools regarding active Python processes.\n\nThe following functions are available:\n\n.. automodule:: torchscan.process\n.. currentmodule:: torchscan.process\n\n\n.. autofunction:: get_process_gpu_ram\n"
  },
  {
    "path": "docs/source/torchscan.rst",
    "content": "torchscan\n=========\n\n\n.. currentmodule:: torchscan\n\n\nCrawler\n~~~~~~~\n\n.. autofunction:: crawl_module\n.. autofunction:: summary\n"
  },
  {
    "path": "docs/source/utils.rst",
    "content": "torchscan.utils\n===============\n\n.. currentmodule:: torchscan.utils\n\n.. autofunction:: format_info\n\n.. autofunction:: aggregate_info\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"torchscan\"\ndescription = \"Useful information about your Pytorch module\"\nauthors = [\n    {name = \"François-Guillaume Fernandez\", email = \"fg-feedback@protonmail.com\"}\n]\nreadme = \"README.md\"\nrequires-python = \">=3.8,<4\"\nlicense = {file = \"LICENSE\"}\nkeywords = [\"pytorch\", \"deep learning\", \"summary\", \"memory\", \"ram\"]\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Natural Language :: English\",\n    \"Operating System :: OS Independent\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.8\",\n    \"Programming Language :: Python :: 3.9\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: Scientific/Engineering :: Mathematics\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n]\ndynamic = [\"version\"]\ndependencies = [\n    \"torch>=2.0.0,<3.0.0\",\n]\n\n[project.optional-dependencies]\ntest = [\n    \"pytest>=7.3.2\",\n    \"pytest-cov>=3.0.0,<5.0.0\",\n    \"pytest-pretty>=1.0.0,<2.0.0\",\n]\nquality = [\n    \"ruff==0.6.4\",\n    \"mypy==1.14.0\",\n    \"pre-commit>=3.0.0,<4.0.0\",\n]\ndocs = [\n    \"sphinx>=3.0.0,!=3.5.0\",\n    \"furo>=2022.3.4\",\n    \"sphinxemoji>=0.1.8\",\n    \"sphinx-copybutton>=0.3.1\",\n    # Indirect deps\n    # cf. 
https://github.com/readthedocs/readthedocs.org/issues/9038\n    \"Jinja2<3.1\",\n]\ndev = [\n    # test\n    \"pytest>=7.3.2\",\n    \"pytest-cov>=3.0.0,<5.0.0\",\n    \"pytest-pretty>=1.0.0,<2.0.0\",\n    # style\n    \"ruff==0.6.4\",\n    \"mypy==1.14.0\",\n    \"pre-commit>=3.0.0,<4.0.0\",\n    # docs\n    \"sphinx>=3.0.0,!=3.5.0\",\n    \"furo>=2022.3.4\",\n    \"sphinxemoji>=0.1.8\",\n    \"sphinx-copybutton>=0.3.1\",\n    \"Jinja2<3.1\",\n]\n\n[project.urls]\ndocumentation = \"https://frgfm.github.io/torch-scan\"\nrepository = \"https://github.com/frgfm/torch-scan\"\ntracker = \"https://github.com/frgfm/torch-scan/issues\"\nchangelog = \"https://frgfm.github.io/torch-scan/latest/changelog.html\"\n\n[tool.setuptools]\nzip-safe = true\n\n[tool.setuptools.packages.find]\nexclude = [\"docs*\", \"scripts*\", \"tests*\"]\n\n[tool.pytest.ini_options]\ntestpaths = [\"torchscan/\"]\n\n[tool.coverage.run]\nsource = [\"torchscan/\"]\n\n[tool.ruff]\nline-length = 120\ntarget-version = \"py311\"\npreview = true\n\n[tool.ruff.lint]\nselect = [\n    \"F\",  # pyflakes\n    \"E\",  # pycodestyle errors\n    \"W\",  # pycodestyle warnings\n    \"I\",  # isort\n    \"N\",  # pep8-naming\n    \"D101\", \"D103\",  # pydocstyle missing docstring in public function/class\n    \"D201\",\"D202\",\"D207\",\"D208\",\"D214\",\"D215\",\"D300\",\"D301\",\"D417\", \"D419\",  # pydocstyle\n    \"YTT\",  # flake8-2020\n    \"ANN\",  # flake8-annotations\n    \"ASYNC\",  # flake8-async\n    \"S\",  # flake8-bandit\n    \"BLE\",  # flake8-blind-except\n    \"B\",  # flake8-bugbear\n    \"A\",  # flake8-builtins\n    \"COM\",  # flake8-commas\n    \"CPY\",  # flake8-copyright\n    \"C4\",  # flake8-comprehensions\n    \"T10\",  # flake8-debugger\n    \"ISC\",  # flake8-implicit-str-concat\n    \"ICN\",  # flake8-import-conventions\n    \"LOG\",  # flake8-logging\n    \"PIE\",  # flake8-pie\n    \"T20\",  # flake8-print\n    \"PYI\",  # flake8-pyi\n    \"PT\",  # flake8-pytest-style\n    \"Q\", 
   # flake8-quotes\n    \"RET\",  # flake8-return\n    \"SLF\",  # flake8-self\n    \"SIM\",  # flake8-simplify\n    \"ARG\",  # flake8-unused-arguments\n    \"PTH\",  # flake8-use-pathlib\n    \"PERF\",  # perflint\n    \"NPY\",  # numpy\n    \"FAST\",  # fastapi\n    \"FURB\",  # refurb\n    \"RUF\",  # ruff specific\n    \"N\",  # pep8-naming\n]\nignore = [\n    \"E501\",  # line too long, handled by black\n    \"B008\",  # do not perform function calls in argument defaults\n    \"B904\",  # raise from\n    \"C901\",  # too complex\n    \"F403\",  # star imports\n    \"E731\",  # lambda assignment\n    \"C416\",  # list comprehension to list()\n    \"ANN101\",  # missing type annotations on self\n    \"ANN102\",  # missing type annotations on cls\n    \"ANN002\",  # missing type annotations on *args\n    \"ANN003\",  # missing type annotations on **kwargs\n    \"COM812\",  # trailing comma missing\n    \"N812\",  # lowercase imported as non-lowercase\n    \"ISC001\",  # implicit string concatenation (handled by format)\n    \"ANN401\",  # Dynamically typed expressions (typing.Any) are disallowed\n    \"SLF001\",  # Private member accessed\n]\nexclude = [\".git\"]\n\n[tool.ruff.lint.flake8-quotes]\ndocstring-quotes = \"double\"\n\n[tool.ruff.lint.isort]\nknown-first-party = [\"torchscan\", \"app\"]\nknown-third-party = [\"torch\", \"torchvision\"]\n\n[tool.ruff.lint.per-file-ignores]\n\"**/__init__.py\" = [\"I001\", \"F401\", \"CPY001\"]\n\"scripts/**.py\" = [\"D\", \"T201\", \"N812\", \"S101\", \"ANN\"]\n\".github/**.py\" = [\"D\", \"T201\", \"S602\", \"S101\", \"ANN\"]\n\"docs/**.py\" = [\"E402\", \"D103\", \"ANN\", \"A001\", \"ARG001\"]\n\"tests/**.py\" = [\"D101\", \"D103\", \"CPY001\", \"S101\", \"PT011\", \"ANN\", \"SLF001\"]\n\"demo/**.py\" = [\"D103\", \"ANN\"]\n\"setup.py\" = [\"T201\"]\n\"torchscan/process/memory.py\" = [\"S60\"]\n\n[tool.ruff.format]\nquote-style = \"double\"\nindent-style = \"space\"\n\n\n[tool.mypy]\npython_version = \"3.11\"\nfiles 
= \"torchscan/\"\nshow_error_codes = true\npretty = true\nwarn_unused_ignores = true\nwarn_redundant_casts = true\nno_implicit_optional = true\ndisallow_untyped_calls = true\ncheck_untyped_defs = true\nimplicit_reexport = false\ndisallow_untyped_defs = true\nexplicit_package_bases = true\n"
  },
  {
    "path": "scripts/benchmark.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\n\"\"\"\nTorchvision benchmark\n\"\"\"\n\nimport torch\nfrom torchvision import models\n\nfrom torchscan import crawl_module\n\nTORCHVISION_MODELS = [\n    \"alexnet\",\n    \"googlenet\",\n    \"vgg11\",\n    \"vgg11_bn\",\n    \"vgg13\",\n    \"vgg13_bn\",\n    \"vgg16\",\n    \"vgg16_bn\",\n    \"vgg19\",\n    \"vgg19_bn\",\n    \"resnet18\",\n    \"resnet34\",\n    \"resnet50\",\n    \"resnet101\",\n    \"resnet152\",\n    \"inception_v3\",\n    \"squeezenet1_0\",\n    \"squeezenet1_1\",\n    \"wide_resnet50_2\",\n    \"wide_resnet101_2\",\n    \"densenet121\",\n    \"densenet161\",\n    \"densenet169\",\n    \"densenet201\",\n    \"resnext50_32x4d\",\n    \"resnext101_32x8d\",\n    \"mobilenet_v2\",\n    \"shufflenet_v2_x0_5\",\n    \"shufflenet_v2_x1_0\",\n    \"shufflenet_v2_x1_5\",\n    \"shufflenet_v2_x2_0\",\n    \"mnasnet0_5\",\n    \"mnasnet0_75\",\n    \"mnasnet1_0\",\n    \"mnasnet1_3\",\n]\n\n\ndef main():\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    margin = 4\n    headers = [\"Model\", \"Params (M)\", \"FLOPs (G)\", \"MACs (G)\", \"DMAs (G)\", \"RF\"]\n    max_w = [20, 10, 10, 10, 10, 10]\n\n    info_str = [(\" \" * margin).join([f\"{col_name:<{col_w}}\" for col_name, col_w in zip(headers, max_w, strict=False)])]\n    info_str.append(\"-\" * len(info_str[0]))\n    print(\"\\n\".join(info_str))\n    for name in TORCHVISION_MODELS:\n        model = models.__dict__[name]().eval().to(device)\n        dsize = (3, 224, 224)\n        if \"inception\" in name:\n            dsize = (3, 299, 299)\n        model_info = crawl_module(model, dsize)\n\n        tot_params = sum(layer[\"grad_params\"] + layer[\"nograd_params\"] for layer in model_info[\"layers\"])\n        tot_flops = sum(layer[\"flops\"] for 
layer in model_info[\"layers\"])\n        tot_macs = sum(layer[\"macs\"] for layer in model_info[\"layers\"])\n        tot_dmas = sum(layer[\"dmas\"] for layer in model_info[\"layers\"])\n        rf = model_info[\"layers\"][0][\"rf\"]\n        print(\n            f\"{name:<{max_w[0]}} | {tot_params / 1e6:<{max_w[1]}.2f} | {tot_flops / 1e9:<{max_w[2]}.2f} | \"\n            f\"{tot_macs / 1e9:<{max_w[3]}.2f} | {tot_dmas / 1e9:<{max_w[4]}.2f} | {rf:<{max_w[5]}.0f}\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\n\nimport os\nfrom pathlib import Path\n\nfrom setuptools import setup\n\nPKG_NAME = \"torchscan\"\nVERSION = os.getenv(\"BUILD_VERSION\", \"0.2.0.dev0\")\n\n\nif __name__ == \"__main__\":\n    print(f\"Building wheel {PKG_NAME}-{VERSION}\")\n\n    # Dynamically set the __version__ attribute\n    cwd = Path(__file__).parent.absolute()\n    with cwd.joinpath(\"torchscan\", \"version.py\").open(\"w\", encoding=\"utf-8\") as f:\n        f.write(f\"__version__ = '{VERSION}'\\n\")\n\n    setup(name=PKG_NAME, version=VERSION)\n"
  },
  {
    "path": "tests/test_crawler.py",
    "content": "import io\nimport sys\nfrom collections import OrderedDict\n\nimport pytest\nimport torch.nn as nn\n\nfrom torchscan import crawler\n\n\ndef test_apply():\n    multi_convs = nn.Sequential(nn.Conv2d(16, 32, 3), nn.Conv2d(32, 64, 3))\n    mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs)\n\n    # Tag module attributes\n    def tag_name(mod, name):\n        mod.__depth__ = len(name.split(\".\")) - 1\n        mod.__name__ = name.rpartition(\".\")[-1]\n\n    crawler.apply(mod, tag_name)\n\n    assert mod[1][1].__depth__ == 2\n    assert mod[1][1].__name__ == \"1\"\n\n\ndef test_crawl_module():\n    mod = nn.Conv2d(3, 8, 3)\n\n    res = crawler.crawl_module(mod, (3, 32, 32))\n    assert isinstance(res, dict)\n    assert res[\"overall\"][\"grad_params\"] == 224\n    assert res[\"layers\"][0][\"output_shape\"] == (-1, 8, 30, 30)\n\n\ndef test_summary():\n    mod = nn.Conv2d(3, 8, 3)\n\n    # Redirect stdout with StringIO object\n    captured_output = io.StringIO()\n    sys.stdout = captured_output\n    crawler.summary(mod, (3, 32, 32))\n    # Reset redirect.\n    sys.stdout = sys.__stdout__\n    assert captured_output.getvalue().split(\"\\n\")[7] == \"Total params: 224\"\n\n    # Check receptive field\n    captured_output = io.StringIO()\n    sys.stdout = captured_output\n    crawler.summary(mod, (3, 32, 32), receptive_field=True)\n    # Reset redirect.\n    sys.stdout = sys.__stdout__\n    assert captured_output.getvalue().split(\"\\n\")[1].rpartition(\"  \")[-1] == \"Receptive field\"\n    assert captured_output.getvalue().split(\"\\n\")[3].split()[-1] == \"3\"\n    # Check effective stats\n    captured_output = io.StringIO()\n    sys.stdout = captured_output\n    crawler.summary(mod, (3, 32, 32), receptive_field=True, effective_rf_stats=True)\n    # Reset redirect.\n    sys.stdout = sys.__stdout__\n    assert captured_output.getvalue().split(\"\\n\")[1].rpartition(\"  \")[-1] == \"Effective padding\"\n    assert 
captured_output.getvalue().split(\"\\n\")[3].split()[-1] == \"0\"\n\n    # Max depth > model hierarchy\n    with pytest.raises(ValueError):\n        crawler.summary(mod, (3, 32, 32), max_depth=1)\n\n    mod = nn.Sequential(\n        OrderedDict([\n            (\"features\", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))),\n            (\"pool\", nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))),\n            (\"classifier\", nn.Linear(8, 1)),\n        ])\n    )\n\n    captured_output = io.StringIO()\n    sys.stdout = captured_output\n    crawler.summary(mod, (3, 32, 32), max_depth=1)\n    # Reset redirect.\n    sys.stdout = sys.__stdout__\n    assert captured_output.getvalue().split(\"\\n\")[4].startswith(\"├─features \")\n"
  },
  {
    "path": "tests/test_modules.py",
    "content": "import pytest\nimport torch\nfrom torch import nn\n\nfrom torchscan import modules\n\n\nclass MyModule(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n\ndef test_module_flops_warning():\n    with pytest.warns(UserWarning):\n        modules.module_flops(MyModule(), None, None)\n\n\n@pytest.mark.parametrize(\n    (\"mod\", \"input_shape\", \"output_shape\", \"expected_val\"),\n    [\n        # Check for unknown module that it returns 0 and throws a warning\n        (MyModule(), (1,), (1,), 0),\n        # Fully-connected\n        (nn.Linear(8, 4), (1, 8), (1, 4), 4 * (2 * 8 - 1) + 4),\n        (nn.Linear(8, 4, bias=False), (1, 8), (1, 4), 4 * (2 * 8 - 1)),\n        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 2 * (4 * (2 * 8 - 1) + 4)),\n        # Activations\n        (nn.Identity(), (1, 8), (1, 8), 0),\n        (nn.Flatten(), (1, 8), (1, 8), 0),\n        (nn.ReLU(), (1, 8), (1, 8), 8),\n        (nn.ELU(), (1, 8), (1, 8), 48),\n        (nn.LeakyReLU(), (1, 8), (1, 8), 32),\n        (nn.ReLU6(), (1, 8), (1, 8), 16),\n        (nn.Tanh(), (1, 8), (1, 8), 48),\n        (nn.Sigmoid(), (1, 8), (1, 8), 32),\n        # BN\n        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 144 + 32 + 32 * 3 + 48),\n        # Pooling\n        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        (nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        # Dropout\n        (nn.Dropout(), (1, 8), (1, 8), 8),\n        (nn.Dropout(p=0), (1, 8), (1, 8), 0),\n        # Conv\n        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 388800),\n        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 499408),\n    ],\n)\ndef 
test_module_flops(mod, input_shape, output_shape, expected_val):\n    assert modules.module_flops(mod, (torch.zeros(input_shape),), torch.zeros(output_shape)) == expected_val\n\n\ndef test_transformer_flops():\n    mod = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=3)\n    src = torch.rand((10, 16, 64))\n    tgt = torch.rand((20, 16, 64))\n    assert modules.module_flops(mod, (src, tgt), mod(src, tgt)) == 774952841\n\n\ndef test_module_macs_warning():\n    with pytest.warns(UserWarning):\n        modules.module_macs(MyModule(), None, None)\n\n\n@pytest.mark.parametrize(\n    (\"mod\", \"input_shape\", \"output_shape\", \"expected_val\"),\n    [\n        # Check for unknown module that it returns 0 and throws a warning\n        (MyModule(), (1,), (1,), 0),\n        # Fully-connected\n        (nn.Linear(8, 4), (1, 8), (1, 4), 8 * 4),\n        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 8 * 4 * 2),\n        # Activations\n        (nn.ReLU(), (1, 8), (1, 8), 0),\n        # BN\n        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 64 + 24 + 56 + 32),\n        # Pooling\n        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),\n        (nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        (nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),\n        # Dropout\n        (nn.Dropout(), (1, 8), (1, 8), 0),\n        # Conv\n        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 194400),\n        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 249704),\n    ],\n)\ndef test_module_macs(mod, input_shape, output_shape, expected_val):\n    assert modules.module_macs(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val\n\n\ndef test_module_dmas_warning():\n    with 
pytest.warns(UserWarning):\n        modules.module_dmas(MyModule(), None, None)\n\n\n@pytest.mark.parametrize(\n    (\"mod\", \"input_shape\", \"output_shape\", \"expected_val\"),\n    [\n        # Check for unknown module that it returns 0 and throws a warning\n        (MyModule(), (1,), (1,), 0),\n        # Fully-connected\n        (nn.Linear(8, 4), (1, 8), (1, 4), 4 * (8 + 1) + 8 + 4),\n        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 4 * (8 + 1) + 2 * (8 + 4)),\n        # Activations\n        (nn.Identity(), (1, 8), (1, 8), 8),\n        (nn.Flatten(), (1, 8), (1, 8), 16),\n        (nn.ReLU(), (1, 8), (1, 8), 8 * 2),\n        (nn.ReLU(inplace=True), (1, 8), (1, 8), 8),\n        (nn.ELU(), (1, 8), (1, 8), 17),\n        (nn.Tanh(), (1, 8), (1, 8), 24),\n        (nn.Sigmoid(), (1, 8), (1, 8), 16),\n        # BN\n        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 32 + 17 + 16 + 1 + 17 + 32),\n        # Pooling\n        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),\n        (nn.MaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),\n        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),\n        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),\n        # Dropout\n        (nn.Dropout(), (1, 8), (1, 8), 17),\n        # Conv\n        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 201824),\n        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 259178),\n    ],\n)\ndef test_module_dmas(mod, input_shape, output_shape, expected_val):\n    assert modules.module_dmas(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val\n\n\n# @torch.no_grad()\n# def test_module_rf(self):\n\n#     # Check for unknown module that it returns 0 and throws a warning\n#     self.assertEqual(modules.module_rf(MyModule(), None, None), (1, 1, 0))\n#     self.assertWarns(UserWarning, modules.module_rf, MyModule(), None, None)\n\n#     # Common unit tests\n#     # Linear\n#     
self.assertEqual(modules.module_rf(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))),\n#                      (1, 1, 0))\n#     # Activation\n#     self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     self.assertEqual(modules.module_rf(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n#     # Conv\n#     input_t = torch.rand((1, 3, 32, 32))\n#     mod = nn.Conv2d(3, 8, 3)\n#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (3, 1, 0))\n#     # Check for dilation support\n#     mod = nn.Conv2d(3, 8, 3, dilation=2)\n#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (5, 1, 0))\n#     # ConvTranspose\n#     mod = nn.ConvTranspose2d(3, 8, 3)\n#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (-3, 1, 0))\n#     # BN\n#     self.assertEqual(modules.module_rf(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))),\n#                      (1, 1, 0))\n\n#     # Pooling\n#     self.assertEqual(modules.module_rf(nn.MaxPool2d((2, 2)),\n#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),\n#                      (2, 2, 0))\n#     self.assertEqual(modules.module_rf(nn.AdaptiveMaxPool2d((2, 2)),\n#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),\n#                      (2, 2, 0))\n\n#     # Dropout\n#     self.assertEqual(modules.module_rf(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))\n"
  },
  {
    "path": "tests/test_process.py",
    "content": "import os\n\nimport torch\n\nfrom torchscan import process\n\n\ndef test_get_process_gpu_ram():\n    if torch.cuda.is_initialized():\n        assert process.get_process_gpu_ram(os.getpid()) >= 0\n    else:\n        assert process.get_process_gpu_ram(os.getpid()) == 0\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "import pytest\n\nfrom torchscan import utils\n\n\ndef test_format_name():\n    name = \"mymodule\"\n    assert utils.format_name(name) == name\n    assert utils.format_name(name, depth=1) == f\"├─{name}\"\n    assert utils.format_name(name, depth=3) == f\"|    |    └─{name}\"\n\n\ndef test_wrap_string():\n    example = \".\".join([\"a\" for _ in range(10)])\n    max_len = 10\n    wrap = \"[...]\"\n\n    assert utils.wrap_string(example, max_len, mode=\"end\") == example[: max_len - len(wrap)] + wrap\n    assert utils.wrap_string(example, max_len, mode=\"mid\") == f\"{example[: max_len - 2 - len(wrap)]}{wrap}.a\"\n    assert utils.wrap_string(example, len(example), mode=\"end\") == example\n    with pytest.raises(ValueError):\n        _ = utils.wrap_string(example, max_len, mode=\"test\")\n\n\n@pytest.mark.parametrize(\n    (\"input_val\", \"num_val\", \"unit\"),\n    [\n        (3e14, 300, \"T\"),\n        (3e10, 30, \"G\"),\n        (3e7, 30, \"M\"),\n        (15e3, 15, \"k\"),\n        (500, 500, \"\"),\n    ],\n)\ndef test_unit_scale(input_val, num_val, unit):\n    assert utils.unit_scale(input_val) == (num_val, unit)\n"
  },
  {
    "path": "torchscan/__init__.py",
    "content": "from contextlib import suppress\nfrom torchscan import modules, process, utils\nfrom torchscan.crawler import *\n\nwith suppress(ImportError):\n    from .version import __version__\n"
  },
  {
    "path": "torchscan/crawler.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport os\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union\n\nimport torch\nfrom torch.nn import Module\n\nfrom .modules import module_dmas, module_flops, module_macs, module_rf\nfrom .process import get_process_gpu_ram\nfrom .utils import aggregate_info, format_info\n\n__all__ = [\"crawl_module\", \"summary\"]\n\n\ndef apply(module: Module, fn: Callable[[Module, str], None], name: Optional[str] = None) -> None:\n    \"\"\"Modified version of `torch.nn.Module.apply` method\n\n    Args:\n        module: target module\n        fn: function to apply to each module\n        name: name of the current module\n    \"\"\"\n    if name is None:\n        name = module.__class__.__name__.lower()\n    fn(module, name)\n    for n, m in module.named_children():\n        apply(m, fn, f\"{name}.{n}\")\n\n\ndef crawl_module(\n    module: Module,\n    input_shape: Union[List[Tuple[int, ...]], Tuple[int, ...]],\n    dtype: Optional[Union[torch.dtype, Iterable[torch.dtype]]] = None,\n) -> Dict[str, Any]:\n    \"\"\"Retrieves module information for an expected input tensor shape\n\n    >>> import torch.nn as nn\n    >>> from torchscan import summary\n    >>> mod = nn.Conv2d(3, 8, 3)\n    >>> module_info = crawl_module(mod, (3, 224, 224))\n\n    Args:\n        module: module to inspect\n        input_shape: expected input shapes\n        dtype: data type of each input argument to the module\n    Returns:\n        layer and overhead information\n    \"\"\"\n    # Get device and data types from model\n    p = next(module.parameters())\n    device = p.device\n\n    cuda_overhead, framework_overhead = 0.0, 0.0\n    if torch.cuda.is_available():\n        # Process RAM - allocator RAM\n        cuda_overhead = 
get_process_gpu_ram(os.getpid()) - (torch.cuda.memory_reserved() / 1024**2)\n        # Allocator RAM - Used RAM\n        framework_overhead = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024**2\n\n    # input\n    if not isinstance(input_shape, list):\n        input_shape = [input_shape]\n    if dtype is None:\n        dtype = p.data.dtype\n    if isinstance(dtype, torch.dtype):\n        dtype = [dtype] * len(input_shape)\n    # Tensor arguments\n    input_ts = [\n        torch.rand(1, *in_shape).to(dtype=_dtype, device=device)\n        for in_shape, _dtype in zip(input_shape, dtype, strict=False)\n    ]\n\n    pre_fw_handles, post_fw_handles = [], []\n    pre_hook_tracker: Dict[int, Any] = {}\n    post_hook_tracker: Dict[int, Any] = {}\n\n    # Hook definition\n    def _hook_info(module: Module, name: str) -> None:\n        def _pre_hook(module: Module, inp: torch.Tensor) -> None:\n            \"\"\"Pre-forward hook\"\"\"\n            # Check that another hook has not been triggered at this forward stage\n            if not pre_hook_tracker[id(module)][\"is_used\"] and (\n                pre_hook_tracker[id(module)][\"target\"] == pre_hook_tracker[id(module)][\"current\"]\n            ):\n                # Add information\n                # Params\n                grad_params, nograd_params, param_size = 0, 0, 0\n                num_buffers, buffer_size = 0, 0\n                is_shared = False\n                if not any(module.children()):\n                    # Parameters\n                    for p in module.parameters():\n                        if id(p) not in param_ids:\n                            if p.requires_grad:\n                                grad_params += p.data.numel()\n                            else:\n                                nograd_params += p.data.numel()\n                            param_size += p.data.numel() * p.data.element_size()\n                            param_ids.append(id(p))\n                        
else:\n                            is_shared = True\n                    # Buffers\n                    for b in module.buffers():\n                        if id(b) not in param_ids:\n                            num_buffers += b.numel()\n                            buffer_size += b.numel() * b.element_size()\n                            param_ids.append(id(b))\n                        else:\n                            is_shared = True\n\n                if call_idxs.get(id(module)) is None:\n                    call_idxs[id(module)] = [len(info)]\n                else:\n                    call_idxs[id(module)].append(len(info))\n\n                info.append({\n                    \"name\": name.rpartition(\".\")[-1],\n                    \"depth\": len(name.split(\".\")) - 1,\n                    \"type\": module.__class__.__name__,\n                    \"input_shape\": (-1, *inp[0][0].shape[1:]),\n                    \"output_shape\": None,\n                    \"grad_params\": grad_params,\n                    \"nograd_params\": nograd_params,\n                    \"param_size\": param_size,\n                    \"num_buffers\": num_buffers,\n                    \"buffer_size\": buffer_size,\n                    \"flops\": 0,\n                    \"macs\": 0,\n                    \"dmas\": 0,\n                    \"rf\": 1,\n                    \"s\": 1,\n                    \"p\": 0,\n                    \"is_shared\": is_shared,\n                    \"is_leaf\": not any(module.children()),\n                })\n                # Mark the next hook for execution\n                pre_hook_tracker[id(module)][\"target\"] += 1\n                # Current pass already used one of the hooks\n                pre_hook_tracker[id(module)][\"is_used\"] = True\n            pre_hook_tracker[id(module)][\"current\"] += 1\n            # All the hooks have been checked, reset the temporary values\n            if pre_hook_tracker[id(module)][\"current\"] == 
len(module._forward_pre_hooks):\n                pre_hook_tracker[id(module)][\"current\"] = 0\n                pre_hook_tracker[id(module)][\"is_used\"] = False\n\n        def _fwd_hook(module: Module, inputs: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:\n            \"\"\"Post-forward hook\"\"\"\n            # Check that another hook has not been triggered at this forward stage\n            if not post_hook_tracker[id(module)][\"is_used\"] and (\n                post_hook_tracker[id(module)][\"target\"] == post_hook_tracker[id(module)][\"current\"]\n            ):\n                # Write information\n                # Retrieve forward index\n                if len(call_idxs[id(module)]) == 1:\n                    fw_idx = call_idxs[id(module)][0]\n                else:\n                    # The first dictionary with output_shape=None is the correct one\n                    for _idx in call_idxs[id(module)]:\n                        if info[_idx][\"output_shape\"] is None:\n                            fw_idx = _idx\n                            break\n\n                if any(module.children()):\n                    tot_flops, tot_macs, tot_dmas = 0, 0, 0\n                    current_rf, current_stride, current_padding = 1.0, 1.0, 0.0\n                else:\n                    # Compute stats for standalone layers\n                    tot_flops = module_flops(module, inputs, out)\n                    tot_macs = module_macs(module, inputs[0], out)\n                    tot_dmas = module_dmas(module, inputs[0], out)\n                    current_rf, current_stride, current_padding = module_rf(module, inputs[0], out)\n\n                # Update layer information\n                info[fw_idx][\"output_shape\"] = (-1, *out.shape[1:])\n                # Add them, since some modules can be used several times\n                info[fw_idx][\"flops\"] = tot_flops\n                info[fw_idx][\"macs\"] = tot_macs\n                info[fw_idx][\"dmas\"] = tot_dmas\n 
               # Compute receptive field\n                info[fw_idx][\"rf\"] = current_rf\n                info[fw_idx][\"s\"] = current_stride\n                info[fw_idx][\"p\"] = current_padding\n\n                # Mark the next hook for execution\n                post_hook_tracker[id(module)][\"target\"] += 1\n                # Current pass already used one of the hooks\n                post_hook_tracker[id(module)][\"is_used\"] = True\n            post_hook_tracker[id(module)][\"current\"] += 1\n            # All the hooks have been checked, reset the temporary values\n            if post_hook_tracker[id(module)][\"current\"] == len(module._forward_pre_hooks):\n                post_hook_tracker[id(module)][\"current\"] = 0\n                post_hook_tracker[id(module)][\"is_used\"] = False\n\n        pre_fw_handles.append(module.register_forward_pre_hook(_pre_hook))  # type: ignore[arg-type]\n        post_fw_handles.append(module.register_forward_hook(_fwd_hook))\n        # Handle modules that are used multiple times (with several hooks)\n        pre_hook_tracker[id(module)] = {\"current\": 0, \"target\": 0, \"is_used\": False}\n        post_hook_tracker[id(module)] = {\"current\": 0, \"target\": 0, \"is_used\": False}\n\n    # Hook model\n    info: List[Dict[str, Any]] = []\n    param_ids: List[int] = []\n    call_idxs: Dict[int, List[int]] = {}\n    apply(module, _hook_info)\n\n    # Forward\n    with torch.no_grad():\n        module(*input_ts)\n\n    # Removes all hooks using their handles\n    for handle in pre_fw_handles:\n        handle.remove()\n    for handle in post_fw_handles:\n        handle.remove()\n\n    reserved_ram, diff_ram = 0.0, 0.0\n    if torch.cuda.is_available():\n        reserved_ram = torch.cuda.memory_reserved() / 1024**2\n        diff_ram = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024**2\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n\n    grad_params, nograd_params, param_size = 0, 
0, 0\n    num_buffers, buffer_size = 0, 0\n    for p in module.parameters():\n        if p.requires_grad:\n            grad_params += p.data.numel()\n        else:\n            nograd_params += p.data.numel()\n        param_size += p.data.numel() * p.data.element_size()\n    for b in module.buffers():\n        num_buffers += b.numel()\n        buffer_size += b.numel() * b.element_size()\n\n    # Update cumulative receptive field\n    _rf, _s, _p = 1, 1, 0\n    for fw_idx, _layer in enumerate(info):\n        _rf += _s * (_layer[\"rf\"] - 1)\n        _p += _s * _layer[\"p\"]\n        _s *= _layer[\"s\"]\n        info[fw_idx][\"rf\"] = _rf\n        info[fw_idx][\"s\"] = _s\n        info[fw_idx][\"p\"] = _p\n\n    return {\n        \"overheads\": {\n            \"cuda\": {\n                \"pre\": cuda_overhead,\n                \"fwd\": get_process_gpu_ram(os.getpid()) - reserved_ram,\n            },\n            \"framework\": {\"pre\": framework_overhead, \"fwd\": diff_ram},\n        },\n        \"layers\": info,\n        \"overall\": {\n            \"grad_params\": grad_params,\n            \"nograd_params\": nograd_params,\n            \"param_size\": param_size,\n            \"num_buffers\": num_buffers,\n            \"buffer_size\": buffer_size,\n        },\n    }\n\n\ndef summary(\n    module: Module,\n    input_shape: Tuple[int, ...],\n    wrap_mode: str = \"mid\",\n    max_depth: Optional[int] = None,\n    receptive_field: bool = False,\n    effective_rf_stats: bool = False,\n) -> None:\n    \"\"\"Print module summary for an expected input tensor shape\n\n    >>> import torch.nn as nn\n    >>> from torchscan import summary\n    >>> mod = nn.Conv2d(3, 8, 3)\n    >>> summary(mod, (3, 224, 224), receptive_field=True)\n\n    Args:\n        module: module to inspect\n        input_shape: expected input shapes (don't include batch size)\n        wrap_mode: if a value is too long, where the wrapping should be performed\n        max_depth: maximum depth of layer 
information\n        receptive_field: whether receptive field estimation should be performed\n        effective_rf_stats: if `receptive_field` is True, displays effective stride and padding\n    \"\"\"\n    # Get the summary dict\n    module_info = crawl_module(module, input_shape)\n    # Aggregate until max_depth\n    if isinstance(max_depth, int):\n        module_info = aggregate_info(module_info, max_depth)\n    # Format it and print it\n    print(format_info(module_info, wrap_mode, receptive_field, effective_rf_stats))  # noqa T201\n"
  },
  {
    "path": "torchscan/modules/__init__.py",
    "content": "from .flops import *\nfrom .macs import *\nfrom .memory import *\nfrom .receptive import *\n"
  },
  {
    "path": "torchscan/modules/flops.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport warnings\nfrom functools import reduce\nfrom operator import mul\nfrom typing import Tuple\n\nimport torch\nfrom torch import Tensor, nn\nfrom torch.nn import Module\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.modules.conv import _ConvNd, _ConvTransposeNd\nfrom torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd\n\n__all__ = [\"module_flops\"]\n\n\ndef module_flops(module: Module, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"Estimate the number of floating point operations performed by the module\n\n    Args:\n        module: PyTorch module\n        inputs: input to the module\n        out: output of the module\n    Returns:\n        number of FLOPs\n    \"\"\"\n    if isinstance(module, (nn.Identity, nn.Flatten)):\n        return 0\n    if isinstance(module, nn.Linear):\n        return flops_linear(module, inputs)\n    if isinstance(module, nn.ReLU):\n        return flops_relu(module, inputs)\n    if isinstance(module, nn.ELU):\n        return flops_elu(module, inputs)\n    if isinstance(module, nn.LeakyReLU):\n        return flops_leakyrelu(module, inputs)\n    if isinstance(module, nn.ReLU6):\n        return flops_relu6(module, inputs)\n    if isinstance(module, nn.Tanh):\n        return flops_tanh(module, inputs)\n    if isinstance(module, nn.Sigmoid):\n        return flops_sigmoid(module, inputs)\n    if isinstance(module, _ConvTransposeNd):\n        return flops_convtransposend(module, inputs, out)\n    if isinstance(module, _ConvNd):\n        return flops_convnd(module, inputs, out)\n    if isinstance(module, _BatchNorm):\n        return flops_bn(module, inputs)\n    if isinstance(module, _MaxPoolNd):\n        return flops_maxpool(module, 
inputs, out)\n    if isinstance(module, _AvgPoolNd):\n        return flops_avgpool(module, inputs, out)\n    if isinstance(module, _AdaptiveMaxPoolNd):\n        return flops_adaptive_maxpool(module, inputs, out)\n    if isinstance(module, _AdaptiveAvgPoolNd):\n        return flops_adaptive_avgpool(module, inputs, out)\n    if isinstance(module, nn.Dropout):\n        return flops_dropout(module, inputs)\n    if isinstance(module, nn.Transformer):\n        return flops_transformer(module, inputs)\n    warnings.warn(f\"Module type not supported: {module.__class__.__name__}\", stacklevel=1)\n    return 0\n\n\ndef flops_linear(module: nn.Linear, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.Linear`\"\"\"\n    # batch size * out_chan * in_chan\n    num_out_feats = module.out_features * reduce(mul, inputs[0].shape[:-1])\n    mm_flops = num_out_feats * (2 * module.in_features - 1)\n    bias_flops = num_out_feats if module.bias is not None else 0\n\n    return mm_flops + bias_flops\n\n\ndef flops_sigmoid(_: nn.Sigmoid, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.Sigmoid`\"\"\"\n    # For each element, mul by -1, exp it, add 1, div\n    return inputs[0].numel() * 4\n\n\ndef flops_relu(_: nn.ReLU, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.ReLU`\"\"\"\n    # Each element is compared to 0\n    return inputs[0].numel()\n\n\ndef flops_elu(_: nn.ELU, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.ELU`\"\"\"\n    # For each element, compare it to 0, exp it, sub 1, mul by alpha, compare it to 0 and sum both\n    return inputs[0].numel() * 6\n\n\ndef flops_leakyrelu(_: nn.LeakyReLU, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.LeakyReLU`\"\"\"\n    # For each element, compare it to 0 (max), compare it to 0 (min), mul by slope and sum both\n    return inputs[0].numel() * 4\n\n\ndef flops_relu6(_: nn.ReLU6, inputs: 
Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.ReLU6`\"\"\"\n    # For each element, compare it to 0 (max) and compare it to 6 (min)\n    return inputs[0].numel() * 2\n\n\ndef flops_tanh(_: nn.Tanh, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.Tanh`\"\"\"\n    # For each element, exp it, mul by -1 and exp it, divide the sub by the add\n    return inputs[0].numel() * 6\n\n\ndef flops_dropout(module: nn.Dropout, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.Dropout`\"\"\"\n    if module.p > 0:\n        # Sample a random number for each input element\n        return inputs[0].numel()\n    return 0\n\n\ndef flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.conv._ConvTransposeNd`\"\"\"\n    # Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)\n    # Define min and max sizes\n    padding_flops = len(module.kernel_size) * 8\n\n    # Once padding is determined, the operations are almost identical to those of a convolution\n    conv_flops = flops_convnd(module, inputs, out)\n\n    return padding_flops + conv_flops\n\n\ndef flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.conv._ConvNd`\"\"\"\n    # For each position, # mult = kernel size, # adds = kernel size - 1\n    window_flops_per_chan = 2 * reduce(mul, module.kernel_size) - 1\n    # Connections to input channels is controlled by the group parameter\n    effective_in_chan = inputs[0].shape[1] // module.groups\n    # N * flops + (N - 1) additions\n    window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1)\n    conv_flops = out.numel() * window_flops\n\n    # Each output element gets a bias addition\n    bias_flops = out.numel() if module.bias is not None else 0\n\n  
  return conv_flops + bias_flops\n\n\ndef flops_bn(module: _BatchNorm, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.batchnorm._BatchNorm`\"\"\"\n    # for each channel, add eps and running_var, sqrt it\n    norm_ops = module.num_features * 2\n    # For each element, sub running_mean, div by denom\n    norm_ops += inputs[0].numel() * 2\n    # For each element, mul by gamma, add beta\n    scale_ops = inputs[0].numel() * 2 if module.affine else 0\n    bn_flops = norm_ops + scale_ops\n\n    # Count tracking stats update ops\n    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101\n    tracking_flops = 0\n    if module.track_running_stats and module.training:\n        # exponential_average_factor\n        if module.momentum is None:\n            tracking_flops += 1\n        # running_mean: by channel, sum values and div by batch size\n        tracking_flops += inputs[0].numel()\n        # running_var: by channel, sub mean and square values, sum them, divide by batch size\n        tracking_flops += 3 * inputs[0].numel()\n        # Update both running stats: rescale previous value (mul by N), add it the new one, then div by (N + 1)\n        tracking_flops += 2 * module.num_features * 3\n\n    return bn_flops + tracking_flops\n\n\ndef flops_maxpool(module: _MaxPoolNd, _: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.pooling._MaxPoolNd`\"\"\"\n    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size\n\n    # for each spatial output element, check max element in kernel scope\n    return out.numel() * (k_size - 1)\n\n\ndef flops_avgpool(module: _AvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.pooling._AvgPoolNd`\"\"\"\n    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size\n\n    # for 
each spatial output element, sum elements in kernel scope and div by kernel size\n    return out.numel() * (k_size - 1 + inputs[0].ndim - 2)\n\n\ndef flops_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`\"\"\"\n    # Approximate kernel_size using ratio of spatial shapes between input and output\n    kernel_size = tuple(\n        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1\n        for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:], strict=False)\n    )\n\n    # for each spatial output element, check max element in kernel scope\n    return out.numel() * (reduce(mul, kernel_size) - 1)\n\n\ndef flops_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`\"\"\"\n    # Approximate kernel_size using ratio of spatial shapes between input and output\n    kernel_size = tuple(\n        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1\n        for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:], strict=False)\n    )\n\n    # for each spatial output element, sum elements in kernel scope and div by kernel size\n    return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))\n\n\ndef flops_layernorm(module: nn.LayerNorm, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.LayerNorm`\"\"\"\n    # Compute current mean\n    norm_ops = reduce(mul, module.normalized_shape) * inputs[0].shape[: -len(module.normalized_shape)].numel()\n    # current var (sub the mean, square it, sum them, divide by remaining shape)\n    norm_ops += 3 * inputs[0].numel()\n    # for each channel, add eps and running_var, sqrt it\n    norm_ops += reduce(mul, module.normalized_shape) * 2\n    # For each element, sub 
running_mean, div by denom\n    norm_ops += inputs[0].numel() * 2\n    # For each element, mul by gamma, add beta\n    scale_ops = inputs[0].numel() * 2 if module.elementwise_affine else 0\n\n    return norm_ops + scale_ops\n\n\ndef flops_mha(module: nn.MultiheadAttention, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.MultiheadAttention`\"\"\"\n    # Input projection\n    q, k, _ = inputs[:3]\n    batch_size = q.shape[1]\n    if module._qkv_same_embed_dim:\n        tot_flops = 3 * flops_linear(\n            nn.Linear(\n                module.in_proj_weight.shape[1], module.in_proj_weight.shape[0], bias=module.in_proj_bias is not None\n            ),\n            (torch.empty((batch_size, module.in_proj_weight.shape[1])),),\n        )\n    else:\n        tot_flops = flops_linear(\n            nn.Linear(\n                module.q_proj_weight.shape[1], module.q_proj_weight.shape[0], bias=module.in_proj_bias is not None\n            ),\n            (torch.empty((batch_size, module.q_proj_weight.shape[1])),),\n        )\n        tot_flops += flops_linear(\n            nn.Linear(module.k_proj_weight.shape[1], module.k_proj_weight.shape[0], bias=module.bias_k is not None),\n            (torch.empty((batch_size, module.k_proj_weight.shape[1])),),\n        )\n        tot_flops += flops_linear(\n            nn.Linear(module.v_proj_weight.shape[1], module.v_proj_weight.shape[0], bias=module.bias_v is not None),\n            (torch.empty((batch_size, module.v_proj_weight.shape[1])),),\n        )\n\n    # Q (L, B, embed_dim) --> (B * num_heads, L, head_dim=embed_dim / num_heads)\n\n    # Scaled dot-product attention (cf. 
https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L5083)\n    # sqrt the embedding dim and div the Q with it\n    tot_flops += 1 + batch_size * module.num_heads * module.head_dim * q.shape[0]\n    # batched matrix multiply\n    tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0]) * (2 * module.head_dim - 1)\n    # attention mask\n    if inputs[-1] is not None:\n        tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0])\n\n    # softmax\n    tot_flops += batch_size * module.num_heads * q.shape[0] * (3 * k.shape[0] - 1)\n    # dropout\n    if module.dropout > 0:\n        tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0])\n\n    # batched matrix multiply\n    tot_flops += batch_size * module.num_heads * (q.shape[0] * module.head_dim) * (2 * k.shape[0] - 1)\n    # Output linear projection\n    tot_flops += flops_linear(module.out_proj, (torch.empty((q.shape[0], module.out_proj.in_features)),))\n\n    return tot_flops\n\n\ndef flops_transformer_encoderlayer(module: nn.TransformerEncoderLayer, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.TransformerEncoderLayer`\"\"\"\n    tot_flops = flops_mha(module.self_attn, (inputs[0],) * 3)\n\n    tot_flops += flops_dropout(module.dropout1, inputs) + inputs[0].numel()\n    tot_flops += flops_layernorm(module.norm1, inputs)\n    # get linear 1 output size\n    tot_flops += flops_linear(module.linear1, inputs)\n    tot_flops += module_flops(module.activation, inputs, torch.empty(1))  # type: ignore[arg-type]\n    tot_flops += flops_dropout(module.dropout, inputs) + flops_linear(module.linear2, inputs)\n    # get linear 2 output size\n    tot_flops += flops_dropout(module.dropout2, inputs) + inputs[0].numel()\n    tot_flops += flops_layernorm(module.norm2, inputs)\n\n    return tot_flops\n\n\ndef flops_transformer_decoderlayer(module: nn.TransformerDecoderLayer, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation 
for `torch.nn.TransformerDecoderLayer`\"\"\"\n    tot_flops = flops_mha(module.self_attn, (inputs[0],) * 3)\n\n    tot_flops += flops_dropout(module.dropout1, inputs) + inputs[0].numel()\n    tot_flops += flops_layernorm(module.norm1, inputs)\n\n    tot_flops += flops_mha(module.multihead_attn, (inputs[0], inputs[1], inputs[1]))\n    tot_flops += flops_dropout(module.dropout2, inputs) + inputs[0].numel()\n    tot_flops += flops_layernorm(module.norm2, inputs)\n\n    # get linear 1 output size\n    tot_flops += flops_linear(module.linear1, inputs)\n    tot_flops += module_flops(module.activation, inputs, torch.empty(1))  # type: ignore[arg-type]\n    tot_flops += flops_dropout(module.dropout, inputs) + flops_linear(module.linear2, inputs)\n    # get linear 2 output size\n    tot_flops += flops_dropout(module.dropout3, inputs) + inputs[0].numel()\n    tot_flops += flops_layernorm(module.norm3, inputs)\n\n    return tot_flops\n\n\ndef flops_transformer(module: nn.Transformer, inputs: Tuple[Tensor, ...]) -> int:\n    \"\"\"FLOPs estimation for `torch.nn.Transformer`\"\"\"\n    encoder_flops = len(module.encoder.layers) * flops_transformer_encoderlayer(module.encoder.layers[0], inputs)\n\n    if module.encoder.norm is not None:\n        encoder_flops += flops_layernorm(module.encoder.norm, inputs)\n\n    decoder_flops = len(module.decoder.layers) * flops_transformer_decoderlayer(module.decoder.layers[0], inputs)\n\n    if module.decoder.norm is not None:\n        decoder_flops += flops_layernorm(module.decoder.norm, inputs)\n\n    return encoder_flops + decoder_flops\n"
  },
  {
    "path": "torchscan/modules/macs.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport warnings\nfrom functools import reduce\nfrom operator import mul\n\nfrom torch import Tensor, nn\nfrom torch.nn import Module\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.modules.conv import _ConvNd, _ConvTransposeNd\nfrom torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd\n\n__all__ = [\"module_macs\"]\n\n\ndef module_macs(module: Module, inp: Tensor, out: Tensor) -> int:\n    \"\"\"Estimate the number of multiply-accumulation operations performed by the module\n\n    Args:\n        module (torch.nn.Module): PyTorch module\n        inp (torch.Tensor): input to the module\n        out (torch.Tensor): output of the module\n    Returns:\n        int: number of MACs\n    \"\"\"\n    if isinstance(module, nn.Linear):\n        return macs_linear(module, inp, out)\n    if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):\n        return 0\n    if isinstance(module, _ConvTransposeNd):\n        return macs_convtransposend(module, inp, out)\n    if isinstance(module, _ConvNd):\n        return macs_convnd(module, inp, out)\n    if isinstance(module, _BatchNorm):\n        return macs_bn(module, inp, out)\n    if isinstance(module, _MaxPoolNd):\n        return macs_maxpool(module, inp, out)\n    if isinstance(module, _AvgPoolNd):\n        return macs_avgpool(module, inp, out)\n    if isinstance(module, _AdaptiveMaxPoolNd):\n        return macs_adaptive_maxpool(module, inp, out)\n    if isinstance(module, _AdaptiveAvgPoolNd):\n        return macs_adaptive_avgpool(module, inp, out)\n    if isinstance(module, nn.Dropout):\n        return 0\n    warnings.warn(f\"Module type not supported: {module.__class__.__name__}\", 
stacklevel=1)\n    return 0\n\n\ndef macs_linear(module: nn.Linear, _: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.Linear`\"\"\"\n    # batch size * out_chan * macs_per_elt (bias already counted in accumulation)\n    return module.in_features * reduce(mul, out.shape)\n\n\ndef macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.conv._ConvTransposeNd`\"\"\"\n    # Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)\n    # Define min and max sizes, then subtract them\n    padding_macs = len(module.kernel_size) * 4\n\n    # Rest of the operations are almost identical to a convolution (given the padding)\n    conv_macs = macs_convnd(module, inp, out)\n\n    return padding_macs + conv_macs\n\n\ndef macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.conv._ConvNd`\"\"\"\n    # For each position, # mult = kernel size, # adds = kernel size - 1\n    window_macs_per_chan = reduce(mul, module.kernel_size)\n    # Connections to input channels is controlled by the group parameter\n    effective_in_chan = inp.shape[1] // module.groups\n    # N * mac\n    window_mac = effective_in_chan * window_macs_per_chan\n    return out.numel() * window_mac\n\n    # bias already counted in accumulation\n\n\ndef macs_bn(module: _BatchNorm, inp: Tensor, _: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.batchnorm._BatchNorm`\"\"\"\n    # sub mean, div by denom\n    norm_mac = 1\n    # mul by gamma, add beta\n    scale_mac = 1 if module.affine else 0\n\n    # Sum everything up\n    bn_mac = inp.numel() * (norm_mac + scale_mac)\n\n    # Count tracking stats update ops\n    # cf. 
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101\n    tracking_mac = 0\n    b = inp.shape[0]\n    num_spatial_elts = inp.shape[2:].numel()\n    if module.track_running_stats and module.training:\n        # running_mean: by channel, sum value and div by batch size\n        tracking_mac += module.num_features * (b * num_spatial_elts - 1)\n        # running_var: by channel, sub mean and square values, sum them, divide by batch size\n        active_elts = b * num_spatial_elts\n        tracking_mac += module.num_features * (2 * active_elts - 1)\n        # Update both running stats: rescale previous value (mul by N), add it the new one, then div by (N + 1)\n        tracking_mac += 2 * module.num_features * 2\n\n    return bn_mac + tracking_mac\n\n\ndef macs_maxpool(module: _MaxPoolNd, _: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.pooling._MaxPoolNd`\"\"\"\n    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size\n\n    # for each spatial output element, check max element in kernel scope\n    return out.numel() * (k_size - 1)\n\n\ndef macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.pooling._AvgPoolNd`\"\"\"\n    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size\n\n    # for each spatial output element, sum elements in kernel scope and div by kernel size\n    return out.numel() * (k_size - 1 + inp.ndim - 2)\n\n\ndef macs_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`\"\"\"\n    # Approximate kernel_size using ratio of spatial shapes between input and output\n    kernel_size = tuple(\n        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1\n        for i_size, o_size in 
zip(inp.shape[2:], out.shape[2:], strict=False)\n    )\n\n    # for each spatial output element, check max element in kernel scope\n    return out.numel() * (reduce(mul, kernel_size) - 1)\n\n\ndef macs_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"MACs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`\"\"\"\n    # Approximate kernel_size using ratio of spatial shapes between input and output\n    kernel_size = tuple(\n        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1\n        for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False)\n    )\n\n    # for each spatial output element, sum elements in kernel scope and div by kernel size\n    return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))\n"
  },
  {
    "path": "torchscan/modules/memory.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport warnings\nfrom functools import reduce\nfrom operator import mul\nfrom typing import Union\n\nfrom torch import Tensor, nn\nfrom torch.nn import Module\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.modules.conv import _ConvNd, _ConvTransposeNd\nfrom torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd\n\n__all__ = [\"module_dmas\"]\n\n\ndef module_dmas(module: Module, inp: Tensor, out: Tensor) -> int:\n    \"\"\"Estimate the number of direct memory accesses by the module.\n    The implementation overhead is neglected.\n\n    Args:\n        module (torch.nn.Module): PyTorch module\n        inp (torch.Tensor): input to the module\n        out (torch.Tensor): output of the module\n    Returns:\n        int: number of DMAs\n    \"\"\"\n    if isinstance(module, nn.Identity):\n        return dmas_identity(module, inp, out)\n    if isinstance(module, nn.Flatten):\n        return dmas_flatten(module, inp, out)\n    if isinstance(module, nn.Linear):\n        return dmas_linear(module, inp, out)\n    if isinstance(module, (nn.ReLU, nn.ReLU6)):\n        return dmas_relu(module, inp, out)\n    if isinstance(module, (nn.ELU, nn.LeakyReLU)):\n        return dmas_act_single_param(module, inp, out)\n    if isinstance(module, nn.Sigmoid):\n        return dmas_sigmoid(module, inp, out)\n    if isinstance(module, nn.Tanh):\n        return dmas_tanh(module, inp, out)\n    if isinstance(module, _ConvTransposeNd):\n        return dmas_convtransposend(module, inp, out)\n    if isinstance(module, _ConvNd):\n        return dmas_convnd(module, inp, out)\n    if isinstance(module, _BatchNorm):\n        return dmas_bn(module, inp, out)\n    if isinstance(module, (_MaxPoolNd, _AvgPoolNd)):\n       
 return dmas_pool(module, inp, out)\n    if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):\n        return dmas_adaptive_pool(module, inp, out)\n    if isinstance(module, nn.Dropout):\n        return dmas_dropout(module, inp, out)\n    warnings.warn(f\"Module type not supported: {module.__class__.__name__}\", stacklevel=1)\n    return 0\n\n\ndef num_params(module: Module) -> int:\n    \"\"\"Compute the number of parameters\n\n    Args:\n        module (torch.nn.Module): PyTorch module\n    Returns:\n        int: number of parameter elements\n    \"\"\"\n    return sum(p.data.numel() for p in module.parameters())\n\n\ndef dmas_identity(_: nn.Identity, inp: Tensor, __: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Identity`\"\"\"\n    return inp.numel()\n\n\ndef dmas_flatten(_: nn.Flatten, inp: Tensor, __: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Flatten`\"\"\"\n    return 2 * inp.numel()\n\n\ndef dmas_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Linear`\"\"\"\n    input_dma = inp.numel()\n    # Access weight and bias\n    ops_dma = num_params(module)\n    output_dma = out.numel()\n\n    return input_dma + ops_dma + output_dma\n\n\ndef dmas_relu(module: Union[nn.ReLU, nn.ReLU6], inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.ReLU`\"\"\"\n    input_dma = inp.numel()\n    output_dma = 0 if module.inplace else out.numel()\n\n    return input_dma + output_dma\n\n\ndef dmas_act_single_param(module: Union[nn.ELU, nn.LeakyReLU], inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for activations with single parameter\"\"\"\n    input_dma = inp.numel()\n    # Access alpha, slope or other\n    ops_dma = 1\n    output_dma = 0 if module.inplace else out.numel()\n\n    return input_dma + ops_dma + output_dma\n\n\ndef dmas_sigmoid(_: nn.Sigmoid, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Sigmoid`\"\"\"\n    # Access 
for both exp\n    input_dma = inp.numel()\n    output_dma = out.numel()\n\n    return input_dma + output_dma\n\n\ndef dmas_tanh(_: nn.Tanh, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Tanh`\"\"\"\n    # Access for both exp\n    input_dma = inp.numel() * 2\n    output_dma = out.numel()\n\n    return input_dma + output_dma\n\n\ndef dmas_dropout(module: nn.Dropout, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.Dropout`\"\"\"\n    input_dma = inp.numel()\n\n    # Access sampling probability\n    ops_dma = 1\n\n    output_dma = 0 if module.inplace else out.numel()\n\n    return input_dma + ops_dma + output_dma\n\n\ndef dmas_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.modules.conv._ConvTransposeNd`\"\"\"\n    # Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)\n    # Access stride, padding and kernel_size\n    in_padding = len(module.kernel_size) * 4\n    out_padding = len(module.kernel_size)\n\n    # The rest is like a classic convolution\n    conv_dmas = dmas_convnd(module, inp, out)\n\n    return in_padding + out_padding + conv_dmas\n\n\ndef dmas_convnd(module: _ConvNd, _: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.modules.conv._ConvNd`\"\"\"\n    # Each output element required K ** 2 memory access of each input channel\n    input_dma = module.in_channels * reduce(mul, module.kernel_size) * out.numel()\n    # Correct with groups\n    input_dma //= module.groups\n\n    # Access weight & bias\n    ops_dma = num_params(module)\n    output_dma = out.numel()\n\n    return input_dma + ops_dma + output_dma\n\n\ndef dmas_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for `torch.nn.modules.batchnorm._BatchNorm`\"\"\"\n    input_dma = inp.numel()\n\n    # Access running_mean, running_var and eps\n    ops_dma = module.running_mean.numel() 
+ module.running_var.numel() + 1  # type: ignore[union-attr]\n    # Access to weight and bias\n    if module.affine:\n        ops_dma += module.weight.data.numel() + module.bias.data.numel()\n    # Exp avg factor\n    if module.momentum:\n        ops_dma += 1\n    # Update stats\n    if module.training and module.track_running_stats:\n        # Current mean and std computation only requires access to input, already counted in input_dma\n        # Update num of batches and running stats\n        ops_dma += 1 + module.running_mean.numel() + module.running_var.numel()  # type: ignore[union-attr]\n\n    output_dma = out.numel()\n\n    return input_dma + ops_dma + output_dma\n\n\ndef dmas_pool(module: Union[_MaxPoolNd, _AvgPoolNd], inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for spatial pooling modules\"\"\"\n    # Resolve kernel size and stride size (can be stored as a single integer or a tuple)\n    if isinstance(module.kernel_size, tuple):\n        kernel_size = module.kernel_size\n    elif isinstance(module.kernel_size, int):\n        kernel_size = (module.kernel_size,) * (inp.ndim - 2)\n\n    # Each output element required K ** 2 memory accesses\n    input_dma = reduce(mul, kernel_size) * out.numel()\n\n    output_dma = out.numel()\n\n    return input_dma + output_dma\n\n\ndef dmas_adaptive_pool(_: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor) -> int:\n    \"\"\"DMAs estimation for adaptive spatial pooling modules\"\"\"\n    # Approximate kernel_size using ratio of spatial shapes between input and output\n    kernel_size = tuple(\n        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1\n        for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False)\n    )\n    # Each output element required K ** 2 memory accesses\n    input_dma = reduce(mul, kernel_size) * out.numel()\n\n    output_dma = out.numel()\n\n    return input_dma + output_dma\n"
  },
  {
    "path": "torchscan/modules/receptive.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport math\nimport warnings\nfrom typing import Tuple, Union\n\nfrom torch import Tensor, nn\nfrom torch.nn import Module\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.modules.conv import _ConvNd, _ConvTransposeNd\nfrom torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd\n\n__all__ = [\"module_rf\"]\n\n\ndef module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:\n    \"\"\"Estimate the spatial receptive field of the module\n\n    Args:\n        module (torch.nn.Module): PyTorch module\n        inp (torch.Tensor): input to the module\n        out (torch.Tensor): output of the module\n    Returns:\n        receptive field\n        effective stride\n        effective padding\n    \"\"\"\n    if isinstance(\n        module,\n        (\n            nn.Identity,\n            nn.Flatten,\n            nn.ReLU,\n            nn.ELU,\n            nn.LeakyReLU,\n            nn.ReLU6,\n            nn.Tanh,\n            nn.Sigmoid,\n            _BatchNorm,\n            nn.Dropout,\n            nn.Linear,\n        ),\n    ):\n        return 1.0, 1.0, 0.0\n    if isinstance(module, _ConvTransposeNd):\n        return rf_convtransposend(module, inp, out)\n    if isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):\n        return rf_aggregnd(module, inp, out)\n    if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):\n        return rf_adaptive_poolnd(module, inp, out)\n    warnings.warn(f\"Module type not supported: {module.__class__.__name__}\", stacklevel=1)\n    return 1.0, 1.0, 0.0\n\n\ndef rf_convtransposend(module: _ConvTransposeNd, _: Tensor, __: Tensor) -> Tuple[float, float, float]:\n    k = module.kernel_size[0] if 
isinstance(module.kernel_size, tuple) else module.kernel_size\n    s = module.stride[0] if isinstance(module.stride, tuple) else module.stride\n    return -k, 1.0 / s, 0.0\n\n\ndef rf_aggregnd(module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor) -> Tuple[float, float, float]:\n    k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size\n    if hasattr(module, \"dilation\"):\n        d = module.dilation[0] if isinstance(module.dilation, tuple) else module.dilation\n        k = d * (k - 1) + 1\n    s = module.stride[0] if isinstance(module.stride, tuple) else module.stride\n    p = module.padding[0] if isinstance(module.padding, tuple) else module.padding\n    return k, s, p  # type: ignore[return-value]\n\n\ndef rf_adaptive_poolnd(\n    _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor\n) -> Tuple[int, int, float]:\n    stride = math.ceil(inp.shape[-1] / out.shape[-1])\n    kernel_size = stride\n    padding = (inp.shape[-1] - kernel_size * stride) / 2\n\n    return kernel_size, stride, padding\n"
  },
  {
    "path": "torchscan/process/__init__.py",
    "content": "from .memory import *\n"
  },
  {
    "path": "torchscan/process/memory.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nimport re\nimport subprocess  # noqa S404\nimport warnings\n\nimport torch\n\n__all__ = [\"get_process_gpu_ram\"]\n\n\ndef get_process_gpu_ram(pid: int) -> float:\n    \"\"\"Gets the amount of RAM used by a given process on GPU devices\n\n    Args:\n        pid: process ID\n    Returns:\n        RAM usage in Megabytes\n    \"\"\"\n    # PyTorch is not responsible for GPU usage\n    if not torch.cuda.is_available():\n        warnings.warn(\"CUDA is unavailable to PyTorch.\", stacklevel=1)\n        return 0.0\n\n    # Query the running processes on GPUs\n    try:\n        res = subprocess.run([\"nvidia-smi\", \"-q\", \"-d\", \"PIDS\"], capture_output=True).stdout.decode()\n        # Try to locate the process\n        pids = re.findall(r\"Process ID\\s+:\\s([^\\D]*)\", res)\n        for idx, _pid in enumerate(pids):\n            if int(_pid) == pid:\n                return float(re.findall(r\"Used GPU Memory\\s+:\\s([^\\D]*)\", res)[idx])\n\n        # Query total memory used by nvidia\n        res = subprocess.run(\n            [\"nvidia-smi\", \"--query-gpu=memory.used\", \"--format=csv\"], capture_output=True\n        ).stdout.decode()\n        return float(res.split(\"\\n\")[1].split()[0])\n    except FileNotFoundError as e:\n        warnings.warn(f\"raised: {e}. Parsing NVIDIA-SMI failed.\", stacklevel=1)\n\n    # Default to overall RAM usage for this process on the GPU\n    ram_str = torch.cuda.list_gpu_processes().split(\"\\n\")\n    # Take the first process running on the GPU\n    if ram_str[1].startswith(\"process\"):\n        return float(ram_str[1].split()[3])\n\n    # Otherwise assume the process is running exclusively on CPU\n    return 0.0\n"
  },
  {
    "path": "torchscan/utils.py",
    "content": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.\n\nfrom itertools import starmap\nfrom typing import Any, Dict, List, Optional, Tuple\n\n\ndef format_name(name: str, depth: int = 0) -> str:\n    \"\"\"Format a string for nested data printing\n\n    Args:\n        name: input string\n        depth: depth of the nested information\n    Returns:\n        formatted string\n    \"\"\"\n    if depth == 0:\n        return name\n    if depth == 1:\n        return f\"├─{name}\"\n    return f\"{'|    ' * (depth - 1)}└─{name}\"\n\n\ndef wrap_string(s: str, max_len: int, delimiter: str = \".\", wrap: str = \"[...]\", mode: str = \"end\") -> str:\n    \"\"\"Wrap a string into a given length\n\n    Args:\n        s: input string\n        max_len: maximum string length\n        delimiter: character used for delimiting information categories\n        wrap: wrapping sequence used\n        mode: wrapping mode\n    Returns:\n        wrapped string\n    \"\"\"\n    if len(s) <= max_len or mode is None:\n        return s\n\n    if mode == \"end\":\n        return s[: max_len - len(wrap)] + wrap\n    if mode == \"mid\":\n        final_part = s.rpartition(delimiter)[-1]\n        wrapped_end = f\"{wrap}.{final_part}\"\n        return s[: max_len - len(wrapped_end)] + wrapped_end\n    raise ValueError(\"received an unexpected value of argument `mode`\")\n\n\ndef unit_scale(val: float) -> Tuple[float, str]:\n    \"\"\"Rescale value using scale units\n\n    Args:\n        val: input value\n    Returns:\n        tuple of rescaled value and unit\n    \"\"\"\n    if val // 1e12 > 0:\n        return val / 1e12, \"T\"\n    if val // 1e9 > 0:\n        return val / 1e9, \"G\"\n    if val // 1e6 > 0:\n        return val / 1e6, \"M\"\n    if val // 1e3 > 0:\n        return val / 1e3, \"k\"\n    return val, \"\"\n\n\ndef 
format_s(f_string: str, min_w: Optional[int] = None, max_w: Optional[int] = None) -> str:\n    \"\"\"Format number strings\"\"\"\n    if isinstance(min_w, int):\n        f_string = f\"{f_string:<{min_w}}\"\n    if isinstance(max_w, int):\n        f_string = f\"{f_string:.{max_w}}\"\n\n    return f_string\n\n\ndef format_line_str(\n    layer: Dict[str, Any],\n    col_w: Optional[List[int]] = None,\n    wrap_mode: str = \"mid\",\n    receptive_field: bool = False,\n    effective_rf_stats: bool = False,\n) -> List[str]:\n    \"\"\"Wrap all information into multiple lines\"\"\"\n    if not isinstance(col_w, list):\n        col_w = [None] * 7  # type: ignore[list-item]\n\n    max_len = col_w[0] + 3 if isinstance(col_w[0], int) else 100\n    line_str = [\n        format_s(wrap_string(format_name(layer[\"name\"], layer[\"depth\"]), max_len, mode=wrap_mode), col_w[0], col_w[0]),\n        format_s(layer[\"type\"], col_w[1], col_w[1]),\n        format_s(str(layer[\"output_shape\"]), col_w[2], col_w[2]),\n        format_s(f\"{layer['grad_params'] + layer['nograd_params'] + layer['num_buffers']:,}\", col_w[3], col_w[3]),\n    ]\n\n    if receptive_field:\n        line_str.append(format_s(f\"{layer['rf']:.0f}\", col_w[4], col_w[4]))\n        if effective_rf_stats:\n            line_str.extend((\n                format_s(f\"{layer['s']:.0f}\", col_w[5], col_w[5]),\n                format_s(f\"{layer['p']:.0f}\", col_w[6], col_w[6]),\n            ))\n\n    return line_str\n\n\ndef format_info(\n    module_info: Dict[str, Any], wrap_mode: str = \"mid\", receptive_field: bool = False, effective_rf_stats: bool = False\n) -> str:\n    \"\"\"Print module summary for an expected input tensor shape\n\n    Args:\n        module_info: dictionary output of `crawl_module`\n        wrap_mode: wrapping mode\n        receptive_field: whether to display receptive field\n        effective_rf_stats: if `receptive_field` is True, displays effective stride and padding\n    Returns:\n        
formatted information\n    \"\"\"\n    # Set margin between cols\n    margin = 4\n    # Dynamic col width\n    # Init with headers\n    headers = [\"Layer\", \"Type\", \"Output Shape\", \"Param #\", \"Receptive field\", \"Effective stride\", \"Effective padding\"]\n    max_w = [27, 20, 25, 15, 15, 16, 17]\n    col_w = [len(s) for s in headers]\n    for layer in module_info[\"layers\"]:\n        col_w = [\n            max(v, len(s))\n            for v, s in zip(\n                col_w,\n                format_line_str(layer, col_w=None, wrap_mode=wrap_mode, receptive_field=True, effective_rf_stats=True),\n                strict=False,\n            )\n        ]\n\n    # Truncate columns that are too long\n    col_w = list(starmap(min, zip(col_w, max_w, strict=False)))\n\n    if not receptive_field:\n        col_w = col_w[:4]\n        headers = headers[:4]\n    elif not effective_rf_stats:\n        col_w = col_w[:5]\n        headers = headers[:5]\n\n    # Define separating lines\n    line_length = sum(col_w) + (len(col_w) - 1) * margin\n    thin_line = \"_\" * line_length\n    thick_line = \"=\" * line_length\n    dot_line = \"-\" * line_length\n\n    margin_str = \" \" * margin\n\n    # Header\n    info_str = [\n        thin_line,\n        margin_str.join([f\"{col_name:<{col_w}}\" for col_name, col_w in zip(headers, col_w, strict=False)]),\n        thick_line,\n    ]\n\n    # Layers\n    for layer in module_info[\"layers\"]:\n        line_str = format_line_str(layer, col_w, wrap_mode, receptive_field, effective_rf_stats)\n        info_str.append((\" \" * margin).join(line_str))\n\n    # Parameter information\n    num_params = module_info[\"overall\"][\"grad_params\"] + module_info[\"overall\"][\"nograd_params\"]\n    info_str.extend((\n        thick_line,\n        f\"Trainable params: {module_info['overall']['grad_params']:,}\",\n        f\"Non-trainable params: {module_info['overall']['nograd_params']:,}\",\n        f\"Total params: {num_params:,}\",\n    ))\n\n    
# Static RAM usage\n    info_str.append(dot_line)\n\n    # Convert to Megabytes\n    param_size = (module_info[\"overall\"][\"param_size\"] + module_info[\"overall\"][\"buffer_size\"]) / 1024**2\n    overhead = module_info[\"overheads\"][\"framework\"][\"fwd\"] + module_info[\"overheads\"][\"cuda\"][\"fwd\"]\n\n    info_str.extend((\n        f\"Model size (params + buffers): {param_size:.2f} Mb\",\n        f\"Framework & CUDA overhead: {overhead:.2f} Mb\",\n        f\"Total RAM usage: {param_size + overhead:.2f} Mb\",\n    ))\n\n    # FLOPS information\n    info_str.append(dot_line)\n\n    flops, flops_units = unit_scale(sum(layer[\"flops\"] for layer in module_info[\"layers\"]))\n    macs, macs_units = unit_scale(sum(layer[\"macs\"] for layer in module_info[\"layers\"]))\n    dmas, dmas_units = unit_scale(sum(layer[\"dmas\"] for layer in module_info[\"layers\"]))\n\n    info_str.extend((\n        f\"Floating Point Operations on forward: {flops:.2f} {flops_units}FLOPs\",\n        f\"Multiply-Accumulations on forward: {macs:.2f} {macs_units}MACs\",\n        f\"Direct memory accesses on forward: {dmas:.2f} {dmas_units}DMAs\",\n        thin_line,\n    ))\n\n    return \"\\n\".join(info_str)\n\n\ndef aggregate_info(info: Dict[str, Any], max_depth: int) -> Dict[str, Any]:\n    \"\"\"Aggregate module information to a maximum depth\n\n    Args:\n        info: dictionary output of `crawl_module`\n        max_depth: depth at which parent node aggregates children information\n    Returns:\n        edited dictionary information\n    \"\"\"\n    if not any(layer[\"depth\"] == max_depth for layer in info[\"layers\"]):\n        raise ValueError(\"The `max_depth` argument cannot be higher than module depth.\")\n\n    for fw_idx, layer in enumerate(info[\"layers\"]):\n        # Need to aggregate information\n        if not layer[\"is_leaf\"] and layer[\"depth\"] == max_depth:\n            grad_p, nograd_p, p_size, num_buffers, b_size = 0, 0, 0, 0, 0\n            flops, macs, dmas 
= 0, 0, 0\n            for _layer in info[\"layers\"][fw_idx + 1 :]:\n                # Children have superior depth and were hooked after parent\n                if _layer[\"depth\"] <= max_depth:\n                    break\n                # Aggregate all information (flops, macc, ram)\n                flops += _layer[\"flops\"]\n                macs += _layer[\"macs\"]\n                dmas += _layer[\"dmas\"]\n                grad_p += _layer[\"grad_params\"]\n                nograd_p += _layer[\"nograd_params\"]\n                p_size += _layer[\"param_size\"]\n                num_buffers += _layer[\"num_buffers\"]\n                b_size += _layer[\"buffer_size\"]\n                # Take last child effective RF\n                _rf, _s, _p = _layer[\"rf\"], _layer[\"s\"], _layer[\"p\"]\n\n            # Update info\n            info[\"layers\"][fw_idx][\"flops\"] = flops\n            info[\"layers\"][fw_idx][\"macs\"] = macs\n            info[\"layers\"][fw_idx][\"dmas\"] = dmas\n            info[\"layers\"][fw_idx][\"rf\"] = _rf\n            info[\"layers\"][fw_idx][\"s\"] = _s\n            info[\"layers\"][fw_idx][\"p\"] = _p\n            info[\"layers\"][fw_idx][\"grad_params\"] = grad_p\n            info[\"layers\"][fw_idx][\"nograd_params\"] = nograd_p\n            info[\"layers\"][fw_idx][\"param_size\"] = p_size\n            info[\"layers\"][fw_idx][\"num_buffers\"] = num_buffers\n            info[\"layers\"][fw_idx][\"buffer_size\"] = b_size\n\n    # Filter out further depth information\n    info[\"layers\"] = [layer for layer in info[\"layers\"] if layer[\"depth\"] <= max_depth]\n\n    return info\n"
  }
]