Full Code of frgfm/torch-scan for AI

Repository: frgfm/torch-scan
Branch: main
Commit: 56246d9de511
Files: 57
Total size: 169.0 KB

Directory structure:
torch-scan/

├── .conda/
│   └── meta.yaml
├── .github/
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE/
│   │   ├── bug_report.yml
│   │   ├── config.yml
│   │   └── feature_request.yml
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── collect_env.py
│   ├── dependabot.yml
│   ├── labeler.yml
│   ├── release.yml
│   ├── verify_labels.py
│   └── workflows/
│       ├── builds.yml
│       ├── doc-status.yml
│       ├── docs.yml
│       ├── pr-labels.yml
│       ├── publish.yml
│       ├── pull_requests.yml
│       ├── style.yml
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── docs/
│   ├── Makefile
│   ├── README.md
│   ├── build.sh
│   ├── make.bat
│   └── source/
│       ├── _static/
│       │   ├── css/
│       │   │   └── custom.css
│       │   └── js/
│       │       └── custom.js
│       ├── changelog.rst
│       ├── conf.py
│       ├── index.rst
│       ├── installing.rst
│       ├── modules.rst
│       ├── process.rst
│       ├── torchscan.rst
│       └── utils.rst
├── pyproject.toml
├── scripts/
│   └── benchmark.py
├── setup.py
├── tests/
│   ├── test_crawler.py
│   ├── test_modules.py
│   ├── test_process.py
│   └── test_utils.py
└── torchscan/
    ├── __init__.py
    ├── crawler.py
    ├── modules/
    │   ├── __init__.py
    │   ├── flops.py
    │   ├── macs.py
    │   ├── memory.py
    │   └── receptive.py
    ├── process/
    │   ├── __init__.py
    │   └── memory.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .conda/meta.yaml
================================================
{% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %}
{% set project = pyproject.get('project') %}
{% set urls = pyproject.get('project', {}).get('urls') %}
{% set version = environ.get('BUILD_VERSION', '0.2.0.dev0') %}
package:
  name: {{ project.get('name') }}
  version: {{ version }}

source:
  fn: {{ project.get('name') }}-{{ version }}.tar.gz
  url: ../dist/{{ project.get('name') }}-{{ version }}.tar.gz

build:
  noarch: python
  script: python setup.py install --single-version-externally-managed --record=record.txt

requirements:
  host:
    - python >=3.8, <4.0
    - setuptools

  run:
    - pytorch >=2.0.0, <3.0.0

test:
  # Python imports
  imports:
    - torchscan
    - torchscan.modules
    - torchscan.process
    - torchscan.utils
  requires:
    - python

about:
  home: {{ urls.get('repository') }}
  license: Apache 2.0
  license_file: {{ project.get('license', {}).get('file') }}
  summary: {{ project.get('description') }}
  # description: |
  #   {{ data['long_description'] | replace("\n", "\n    ") | replace("#", '\#')}}
  doc_url: {{ urls.get('documentation') }}
  dev_url: {{ urls.get('repository') }}


================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms

github: frgfm
patreon: # Replace with a single Patreon username
open_collective: # Replace with an OpenCollective account
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']


================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.yml
================================================
name: 🐛 Bug report
description: Create a report to help us improve the library
labels: 'type: bug'
assignees: frgfm

body:
- type: markdown
  attributes:
    value: >
      #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/frgfm/torch-scan/issues?q=is%3Aissue).
- type: textarea
  attributes:
    label: Bug description
    description: |
      A clear and concise description of what the bug is.

      Please explain the result you observed and the behavior you were expecting.
    placeholder: |
      A clear and concise description of what the bug is.
  validations:
    required: true

- type: textarea
  attributes:
    label: Code snippet to reproduce the bug
    description: |
      Sample code to reproduce the problem.

      Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability.
    placeholder: |
      ```python
      Sample code to reproduce the problem
      ```
  validations:
    required: true
- type: textarea
  attributes:
    label: Error traceback
    description: |
      The error message you received running the code snippet, with the full traceback.

      Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability.
    placeholder: |
      ```
      The error message you got, with the full traceback.
      ```
  validations:
    required: true
- type: textarea
  attributes:
    label: Environment
    description: |
      Please run the following command and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/frgfm/torch-scan/main/.github/collect_env.py
      # For security purposes, please check the contents of collect_env.py before running it.
      python collect_env.py
      ```
  validations:
    required: true
- type: markdown
  attributes:
    value: >
      Thanks for helping us improve the library!


================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: true
contact_links:
  - name: Usage questions
    url: https://github.com/frgfm/torch-scan/discussions
    about: Ask questions and discuss with other TorchScan community members


================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.yml
================================================
name: 🚀 Feature request
description: Submit a proposal/request for a new feature
labels: 'type: enhancement'
assignees: frgfm

body:
- type: textarea
  attributes:
    label: 🚀 Feature
    description: >
      A clear and concise description of the feature proposal
  validations:
    required: true
- type: textarea
  attributes:
    label: Motivation & pitch
    description: >
      Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
  validations:
    required: true
- type: textarea
  attributes:
    label: Alternatives
    description: >
      A description of any alternative solutions or features you've considered, if any.
- type: textarea
  attributes:
    label: Additional context
    description: >
      Add any other context or screenshots about the feature request.
- type: markdown
  attributes:
    value: >
      Thanks for contributing 🎉


================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
# What does this PR do?

<!--
Well, hello there! Thank you for proposing modifications to the project.

Make sure to provide both a short descriptive title and an explanation of your modifications with the relevant context. Include references to the GitHub issues this PR relates to. For the sake of keeping the library light, if you modified existing dependencies or added new ones, please state it clearly in your description.

-->

<!-- Remove if not applicable -->

Closes # (issue)


## Before submitting
- [ ] Was this discussed/approved in a GitHub [issue](https://github.com/frgfm/torch-scan/issues?q=is%3Aissue) or a [discussion](https://github.com/frgfm/torch-scan/discussions)? Please add a link to it if that's the case.
- [ ] You have read the [contribution guidelines](https://github.com/frgfm/torch-scan/blob/main/CONTRIBUTING.md#submitting-a-pull-request) and followed them in this PR.
- [ ] Did you make sure to update the documentation with your changes? Here are the
      [documentation guidelines](https://github.com/frgfm/torch-scan/tree/main/docs).
- [ ] Did you write any new necessary tests?


================================================
FILE: .github/collect_env.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

"""
Based on https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py
This script outputs relevant system environment info
Run it with `python collect_env.py`.
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import locale
import os
import re
import subprocess  # noqa S404
import sys
from pathlib import Path
from typing import NamedTuple

try:
    import torchscan

    TORCHSCAN_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
    TORCHSCAN_AVAILABLE = False

try:
    import torch

    TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
    TORCH_AVAILABLE = False

PY3 = sys.version_info >= (3, 0)


# System Environment Information
class SystemEnv(NamedTuple):
    torchscan_version: str
    torch_version: str
    os: str
    python_version: str
    is_cuda_available: bool
    cuda_runtime_version: str
    nvidia_driver_version: str
    nvidia_gpu_models: str
    cudnn_version: str


def run(command):
    """Returns (return-code, stdout, stderr)"""
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, err = p.communicate()
    rc = p.returncode
    if PY3:
        enc = locale.getpreferredencoding()
        output = output.decode(enc)
        err = err.decode(enc)
    return rc, output.strip(), err.strip()


def run_and_read_all(run_lambda, command):
    """Runs command using run_lambda; reads and returns entire output if rc is 0"""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    return out


def run_and_parse_first_match(run_lambda, command, regex):
    """Runs command using run_lambda, returns the first regex match if it exists"""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    match = re.search(regex, out)
    if match is None:
        return None
    return match.group(1)


def get_nvidia_driver_version(run_lambda):
    if get_platform() == "darwin":
        cmd = "kextstat | grep -i cuda"
        return run_and_parse_first_match(run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]")
    smi = get_nvidia_smi()
    return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")


def get_gpu_info(run_lambda):
    if get_platform() == "darwin":
        if TORCH_AVAILABLE and torch.cuda.is_available():
            return torch.cuda.get_device_name(None)
        return None
    smi = get_nvidia_smi()
    uuid_regex = re.compile(r" \(UUID: .+?\)")
    rc, out, _ = run_lambda(smi + " -L")
    if rc != 0:
        return None
    # Anonymize GPUs by removing their UUID
    return re.sub(uuid_regex, "", out)


def get_running_cuda_version(run_lambda):
    return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")


def get_cudnn_version(run_lambda):
    """This will return a list of libcudnn.so; it's hard to tell which one is being used"""
    if get_platform() == "win32":
        cudnn_cmd = 'where /R "%CUDA_PATH%\\bin" cudnn*.dll'
    elif get_platform() == "darwin":
        # CUDA libraries and drivers can be found in /usr/local/cuda/. See
        # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
        # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
        # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
        cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
    else:
        cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
    rc, out, _ = run_lambda(cudnn_cmd)
    # find will return 1 if there are permission errors or if not found
    if len(out) == 0 or rc not in (1, 0):
        lib = os.environ.get("CUDNN_LIBRARY")
        if lib is not None and Path(lib).is_file():
            return os.path.realpath(lib)
        return None
    files = set()
    for fn in out.split("\n"):
        fn = os.path.realpath(fn)  # eliminate symbolic links
        if Path(fn).is_file():
            files.add(fn)
    if not files:
        return None
    # Alphabetize the result because the order is non-deterministic otherwise
    files = sorted(files)
    if len(files) == 1:
        return files[0]
    result = "\n".join(files)
    return "Probably one of the following:\n{}".format(result)


def get_nvidia_smi():
    # Note: nvidia-smi is currently available only on Windows and Linux
    smi = "nvidia-smi"
    if get_platform() == "win32":
        system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
        program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
        legacy_path = Path(program_files_root) / "NVIDIA Corporation" / "NVSMI" / smi
        new_path = Path(system_root) / "System32" / smi
        smis = [new_path, legacy_path]
        for candidate_smi in smis:
            if Path(candidate_smi).exists():
                smi = '"{}"'.format(candidate_smi)
                break
    return smi


def get_platform():
    if sys.platform.startswith("linux"):
        return "linux"
    if sys.platform.startswith("win32"):
        return "win32"
    if sys.platform.startswith("cygwin"):
        return "cygwin"
    if sys.platform.startswith("darwin"):
        return "darwin"
    return sys.platform


def get_mac_version(run_lambda):
    return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")


def get_windows_version(run_lambda):
    return run_and_read_all(run_lambda, "wmic os get Caption | findstr /v Caption")


def get_lsb_version(run_lambda):
    return run_and_parse_first_match(run_lambda, "lsb_release -a", r"Description:\t(.*)")


def check_release_file(run_lambda):
    return run_and_parse_first_match(run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"')


def get_os(run_lambda):
    platform = get_platform()

    if platform in ("win32", "cygwin"):
        return get_windows_version(run_lambda)

    if platform == "darwin":
        version = get_mac_version(run_lambda)
        if version is None:
            return None
        return "Mac OSX {}".format(version)

    if platform == "linux":
        # Ubuntu/Debian based
        desc = get_lsb_version(run_lambda)
        if desc is not None:
            return desc

        # Try reading /etc/*-release
        desc = check_release_file(run_lambda)
        if desc is not None:
            return desc

        return platform

    # Unknown platform
    return platform


def get_env_info():
    run_lambda = run

    torchscan_str = torchscan.__version__ if TORCHSCAN_AVAILABLE else "N/A"

    if TORCH_AVAILABLE:
        torch_str = torch.__version__
        cuda_available_str = torch.cuda.is_available()
    else:
        torch_str = cuda_available_str = "N/A"

    return SystemEnv(
        torchscan_version=torchscan_str,
        torch_version=torch_str,
        python_version=".".join(map(str, sys.version_info[:3])),
        is_cuda_available=cuda_available_str,
        cuda_runtime_version=get_running_cuda_version(run_lambda),
        nvidia_gpu_models=get_gpu_info(run_lambda),
        nvidia_driver_version=get_nvidia_driver_version(run_lambda),
        cudnn_version=get_cudnn_version(run_lambda),
        os=get_os(run_lambda),
    )


env_info_fmt = """
TorchScan version: {torchscan_version}
PyTorch version: {torch_version}

OS: {os}

Python version: {python_version}
Is CUDA available: {is_cuda_available}
CUDA runtime version: {cuda_runtime_version}
GPU models and configuration: {nvidia_gpu_models}
Nvidia driver version: {nvidia_driver_version}
cuDNN version: {cudnn_version}
""".strip()


def pretty_str(envinfo):
    def replace_nones(dct, replacement="Could not collect"):
        for key in dct:
            if dct[key] is not None:
                continue
            dct[key] = replacement
        return dct

    def replace_bools(dct, true="Yes", false="No"):
        for key in dct:
            if dct[key] is True:
                dct[key] = true
            elif dct[key] is False:
                dct[key] = false
        return dct

    def maybe_start_on_next_line(string):
        # If `string` is multiline, prepend a \n to it.
        if string is not None and len(string.split("\n")) > 1:
            return "\n{}\n".format(string)
        return string

    mutable_dict = envinfo._asdict()

    # If nvidia_gpu_models is multiline, start on the next line
    mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(envinfo.nvidia_gpu_models)

    # If the machine doesn't have CUDA, report some fields as 'No CUDA'
    dynamic_cuda_fields = [
        "cuda_runtime_version",
        "nvidia_gpu_models",
        "nvidia_driver_version",
    ]
    all_cuda_fields = [*dynamic_cuda_fields, "cudnn_version"]
    all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None for field in dynamic_cuda_fields)
    if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
        for field in all_cuda_fields:
            mutable_dict[field] = "No CUDA"

    # Replace True with Yes, False with No
    mutable_dict = replace_bools(mutable_dict)

    # Replace all None objects with 'Could not collect'
    mutable_dict = replace_nones(mutable_dict)

    return env_info_fmt.format(**mutable_dict)


def get_pretty_env_info():
    """Collects environment information for debugging purposes

    Returns:
        str: environment information
    """
    return pretty_str(get_env_info())


def main():
    print("Collecting environment information...")
    output = get_pretty_env_info()
    print(output)


if __name__ == "__main__":
    main()
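Every probe in the script above funnels through the `run` helper: shell out, capture stdout/stderr, and regex-parse the result. As an illustrative aside (separate from the repository file), a condensed, self-contained sketch of that pattern, using `echo` as a stand-in for the real system probes:

```python
import locale
import re
import subprocess

def run(command):
    """Run a shell command; return (return-code, decoded stdout, decoded stderr)."""
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    output, err = p.communicate()
    enc = locale.getpreferredencoding()
    return p.returncode, output.decode(enc).strip(), err.decode(enc).strip()

def run_and_parse_first_match(command, regex):
    """Return the first capture group found in the command output, or None on failure."""
    rc, out, _ = run(command)
    if rc != 0:
        return None
    match = re.search(regex, out)
    return None if match is None else match.group(1)

rc, out, _ = run("echo hello")
# Same regex the script uses against `nvcc --version`, fed a fake output line
version = run_and_parse_first_match("echo release 12.1, V12.1.105", r"release .+ V(.*)")
print(rc, out, version)  # 0 hello 12.1.105
```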


================================================
FILE: .github/dependabot.yml
================================================
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file

version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "monthly"
      time: "06:00"
      timezone: "Europe/Paris"
    groups:
      gh-actions:
        patterns:
          - "*"
    reviewers:
      - "frgfm"
    assignees:
      - "frgfm"
  - package-ecosystem: "pip"
    directory: "/"
    schedule:
      interval: "daily"
      time: "06:00"
      timezone: "Europe/Paris"
    reviewers:
      - "frgfm"
    assignees:
      - "frgfm"
    allow:
      - dependency-name: "ruff"
      - dependency-name: "mypy"
      - dependency-name: "pre-commit"


================================================
FILE: .github/labeler.yml
================================================
'module: crawler':
- changed-files:
  - any-glob-to-any-file: torchscan/crawler.py

'module: modules':
- changed-files:
  - any-glob-to-any-file: torchscan/modules/*

'module: process':
- changed-files:
  - any-glob-to-any-file: torchscan/process/*

'module: utils':
- changed-files:
  - any-glob-to-any-file: torchscan/utils.py

'ext: docs':
- changed-files:
  - any-glob-to-any-file: docs/*

'ext: scripts':
- changed-files:
  - any-glob-to-any-file: scripts/*

'ext: tests':
- changed-files:
  - any-glob-to-any-file: tests/*

'topic: ci':
- changed-files:
  - any-glob-to-any-file: .github/*

'topic: docs':
- changed-files:
  - any-glob-to-any-file:
    - README.md
    - CONTRIBUTING.md
    - CODE_OF_CONDUCT.md
    - CITATION.cff
    - LICENSE

'topic: build':
- changed-files:
  - any-glob-to-any-file:
    - setup.py
    - pyproject.toml

'topic: style':
- changed-files:
  - any-glob-to-any-file: .pre-commit-config.yaml


================================================
FILE: .github/release.yml
================================================
changelog:
  exclude:
    labels:
      - ignore-for-release
  categories:
    - title: Breaking Changes 🛠
      labels:
        - "type: breaking change"
    # NEW FEATURES
    - title: New Features 🚀
      labels:
        - "type: feat"
    # BUG FIXES
    - title: Bug Fixes 🐛
      labels:
        - "type: fix"
    # IMPROVEMENTS
    - title: Improvements
      labels:
        - "type: improvement"
    # MISC
    - title: Miscellaneous
      labels:
        - "type: misc"


================================================
FILE: .github/verify_labels.py
================================================
# Copyright (C) 2022-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

"""
Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py
This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in
'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled,
this script is a no-op.
Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to the project
with no labeling responsibility, so we don't want to bother them.
"""

from typing import Any, Set, Tuple

import requests

# For a PR to be properly labeled it should have one primary label and one secondary label

# Should specify the type of change
PRIMARY_LABELS = {
    "type: new feature",
    "type: bug",
    "type: enhancement",
    "type: misc",
}

# Should specify what has been modified
SECONDARY_LABELS = {
    "topic: documentation",
    "module: modules",
    "module: process",
    "module: crawler",
    "module: utils",
    "ext: docs",
    "ext: scripts",
    "ext: tests",
    "topic: build",
    "topic: ci",
}

GH_ORG = "frgfm"
GH_REPO = "torch-scan"


def query_repo(cmd: str, *, accept) -> Any:
    response = requests.get(
        f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}", headers={"Accept": accept}, timeout=5
    )
    return response.json()


def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]:
    # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request
    data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json")
    merger = data.get("merged_by", {}).get("login")
    labels = {label["name"] for label in data["labels"]}
    return merger, labels


def main(args):
    merger, labels = get_pr_merger_and_labels(args.pr)
    is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels))
    if isinstance(merger, str) and not is_properly_labeled:
        print(f"@{merger}")


def parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("pr", type=int, help="PR number")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)
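The check in `main` above reduces to two set intersections: a PR passes only if it carries at least one primary (change-type) label and one secondary (area) label. As an illustrative aside, a self-contained sketch of that predicate, with trimmed label sets and hypothetical labels standing in for a fetched PR:

```python
# Trimmed copies of the label groups defined in verify_labels.py
PRIMARY_LABELS = {"type: new feature", "type: bug", "type: enhancement", "type: misc"}
SECONDARY_LABELS = {"module: crawler", "module: modules", "ext: tests", "topic: ci"}

def is_properly_labeled(labels: set) -> bool:
    """A PR needs one label from each group to be considered properly labeled."""
    return bool(PRIMARY_LABELS & labels) and bool(SECONDARY_LABELS & labels)

print(is_properly_labeled({"type: bug", "module: crawler"}))  # True
print(is_properly_labeled({"type: bug"}))                     # False: missing an area label
```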


================================================
FILE: .github/workflows/builds.yml
================================================
name: builds

on:
  push:
    branches: main
  pull_request:
    branches: main

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python: [3.8, 3.9, '3.10', 3.11, 3.12]
        exclude:
          - os: macos-latest
            python: 3.8
          - os: macos-latest
            python: 3.9
          - os: macos-latest
            python: '3.10'
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install package
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e .
      - name: Import package
        run: python -c "import torchscan; print(torchscan.__version__)"

  pypi:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
          architecture: x64
      - name: Install dependencies
        run: |
          python -m pip install --upgrade uv
          uv pip install --system setuptools wheel twine --upgrade
      - run: |
          python setup.py sdist bdist_wheel
          twine check dist/*

  conda:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: conda-incubator/setup-miniconda@v3
        with:
          auto-update-conda: true
          python-version: "3.11"
      - name: Install dependencies
        shell: bash -el {0}
        run: conda install -y conda-build conda-verify
      - name: Build conda
        shell: bash -el {0}
        run: |
          python setup.py sdist
          mkdir conda-dist
          conda env list
          conda build .conda/ -c pytorch --output-folder conda-dist
          ls -l conda-dist/noarch/*tar.bz2


================================================
FILE: .github/workflows/doc-status.yml
================================================
name: GH-Pages Status
on:
  page_build

jobs:
  see-page-build-payload:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
          architecture: x64
      - name: check status
        run: |
          import os
          status, errormsg = os.getenv('STATUS'), os.getenv('ERROR')
          if status != 'built': raise AssertionError(f"There was an error building the page on GitHub pages.\n\nStatus: {status}\n\nError messsage: {errormsg}")
        shell: python
        env:
          STATUS: ${{ github.event.build.status }}
          ERROR: ${{ github.event.build.error.message }}


================================================
FILE: .github/workflows/docs.yml
================================================
name: docs
on:
  push:
    branches: main

jobs:
  docs-deploy:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.9]
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install dependencies
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e ".[docs]"

      - name: Build documentation
        run: cd docs && bash build.sh

      - name: Documentation sanity check
        run: test -e docs/build/index.html || exit

      - name: Install SSH Client 🔑
        uses: webfactory/ssh-agent@v0.9.0
        with:
          ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }}

      - name: Deploy to Github Pages
        uses: JamesIves/github-pages-deploy-action@v4
        with:
          BRANCH: gh-pages
          FOLDER: 'docs/build'
          COMMIT_MESSAGE: '[skip ci] Documentation updates'
          CLEAN: true
          SSH: true


================================================
FILE: .github/workflows/pr-labels.yml
================================================
name: pr-labels

on:
  pull_request:
    branches: main
    types: closed

jobs:
  is-properly-labeled:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - name: Install requests
        run: pip install requests
      - name: Process commit and find merger responsible for labeling
        id: commit
        run: echo "merger=$(python .github/verify_labels.py ${{ github.event.pull_request.number }})" >> "$GITHUB_OUTPUT"
      - name: Comment PR
        uses: actions/github-script@v7
        if: ${{ steps.commit.outputs.merger != '' }}
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const { issue: { number: issue_number }, repo: { owner, repo }  } = context;
            github.rest.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. The list of valid labels is available at https://github.com/frgfm/torch-scan/blob/main/.github/verify_labels.py' });


================================================
FILE: .github/workflows/publish.yml
================================================
name: publish

on:
  release:
    types: [published]

jobs:
  pypi:
    if: "!github.event.release.prerelease"
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
          architecture: x64
      - name: Install dependencies
        run: |
          python -m pip install --upgrade uv
          uv pip install --system setuptools wheel twine --upgrade
      - name: Build and publish
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
        run: |
          echo "BUILD_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
          python setup.py sdist bdist_wheel
          twine check dist/*
          twine upload dist/*

  pypi-check:
    if: "!github.event.release.prerelease"
    runs-on: ubuntu-latest
    needs: pypi
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
          architecture: x64
      - name: Install package
        run: |
          python -m pip install --upgrade uv
          uv pip install --system torchscan
          python -c "import torchscan; print(torchscan.__version__)"

  conda:
    if: "!github.event.release.prerelease"
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Miniconda setup
        uses: conda-incubator/setup-miniconda@v3
        with:
          auto-update-conda: true
          python-version: 3.11
      - name: Install dependencies
        shell: bash -el {0}
        run: conda install -y conda-build conda-verify anaconda-client
      - name: Build and publish
        shell: bash -el {0}
        env:
          ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }}
        run: |
          echo "BUILD_VERSION=${GITHUB_REF#refs/*/v}" >> $GITHUB_ENV
          python setup.py sdist
          mkdir conda-dist
          conda build .conda/ -c pytorch --output-folder conda-dist
          ls -l conda-dist/noarch/*tar.bz2
          anaconda upload conda-dist/noarch/*tar.bz2

  conda-check:
    if: "!github.event.release.prerelease"
    runs-on: ubuntu-latest
    needs: conda
    steps:
      - name: Miniconda setup
        uses: conda-incubator/setup-miniconda@v3
        with:
          auto-update-conda: true
          python-version: 3.11
          auto-activate-base: true
      - name: Install package
        shell: bash -el {0}
        run: |
          conda install -c frgfm torchscan
          python -c "import torchscan; print(torchscan.__version__)"


================================================
FILE: .github/workflows/pull_requests.yml
================================================
name: pull_requests

on:
  pull_request:
    branches: main

jobs:
  docs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: 3.9
          architecture: x64
      - name: Install dependencies
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e ".[docs]"

      - name: Build documentation
        run: cd docs && bash build.sh

      - name: Documentation sanity check
        run: test -e docs/build/index.html || exit

  triage:
    permissions:
      contents: read
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
    - uses: actions/labeler@v5
      with:
        repo-token: "${{ secrets.GITHUB_TOKEN }}"


================================================
FILE: .github/workflows/style.yml
================================================
name: style

on:
  push:
    branches: main
  pull_request:
    branches: main

jobs:
  ruff:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.11]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Run ruff
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e '.[quality]'
          ruff --version
          ruff check --diff .

  mypy:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.11]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Run mypy
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e '.[quality]'
          mypy --version
          mypy

  ruff-format:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.11]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Run ruff
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e '.[quality]'
          ruff --version
          ruff format --check --diff .

  precommit-hooks:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.11]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Run pre-commit hooks
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e '.[quality]'
          git checkout -b temp
          pre-commit install
          pre-commit --version
          pre-commit run --all-files


================================================
FILE: .github/workflows/tests.yml
================================================
name: tests

on:
  push:
    branches: main
  pull_request:
    branches: main

jobs:
  pytest:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python: [3.11]
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}
          architecture: x64
      - name: Install dependencies
        run: |
          python -m pip install --upgrade uv
          uv pip install --system -e ".[test]" --upgrade
      - name: Run unittests
        run: pytest --cov=torchscan --cov-report xml tests/
      - uses: actions/upload-artifact@v4
        with:
          name: coverage-reports
          path: ./coverage.xml

  codecov-upload:
    runs-on: ubuntu-latest
    needs: pytest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/download-artifact@v4
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v5
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          flags: unittests
          directory: ./coverage-reports
          fail_ci_if_error: true

  headers:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Check the headers
        uses: frgfm/validate-python-headers@main
        with:
          license: 'Apache-2.0'
          owner: 'François-Guillaume Fernandez'
          starting-year: 2020
          folders: 'torchscan,scripts,docs,.github'
          ignores: 'version.py,__init__.py'


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
conda-dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Package version
torchscan/version.py
# Conda distribution
conda-dist/


================================================
FILE: .pre-commit-config.yaml
================================================
default_language_version:
    python: python3.11
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-added-large-files
      - id: check-ast
      - id: check-case-conflict
      - id: check-json
      - id: check-merge-conflict
      - id: check-symlinks
      - id: check-toml
      - id: check-xml
      - id: check-yaml
        exclude: .conda
      - id: debug-statements
        language_version: python3
      - id: end-of-file-fixer
      - id: no-commit-to-branch
        args: ['--branch', 'main']
      - id: requirements-txt-fixer
      - id: trailing-whitespace
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: 'v0.6.4'
    hooks:
      - id: ruff
        args:
          - --fix
      - id: ruff-format


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
fg-feedback@protonmail.com.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to torchscan

Everything you need to know to contribute efficiently to the project.

Whatever the way you wish to contribute to the project, please respect the [code of conduct](CODE_OF_CONDUCT.md).


## Codebase structure

- [torchscan](https://github.com/frgfm/torch-scan/blob/main/torchscan) - The actual torchscan library
- [tests](https://github.com/frgfm/torch-scan/blob/main/tests) - Python unit tests
- [docs](https://github.com/frgfm/torch-scan/blob/main/docs) - Sphinx documentation building
- [scripts](https://github.com/frgfm/torch-scan/blob/main/scripts) - Example and utilities scripts



## Continuous Integration

This project uses the following integrations to ensure proper codebase maintenance:

- [GitHub Workflows](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - runs jobs for package builds and coverage
- [Codacy](https://www.codacy.com/) - analyzes commits for code quality
- [Codecov](https://codecov.io/) - reports back coverage results

As a contributor, you will only have to ensure coverage of your code by adding appropriate unit tests.


## Feedback

### Feature requests & bug report

Whether you encountered a problem or you have a feature suggestion, your input is valuable and can be referenced by other contributors in their work. For this purpose, we advise you to use GitHub [issues](https://github.com/frgfm/torch-scan/issues).

First, check whether the topic has already been covered in an open or closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in.

### Questions

If you are wondering how to do something with TorchScan, or have a more general question, consider checking out GitHub [discussions](https://github.com/frgfm/torch-scan/discussions). See it as a Q&A forum, or the TorchScan-specific StackOverflow!



## Submitting a Pull Request

### Preparing your local branch

1 - Fork this [repository](https://github.com/frgfm/torch-scan) by clicking on the "Fork" button at the top right of the page. This will create a copy of the project under your GitHub account (cf. [Fork a repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo)).

2 - [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) to your local disk and set the upstream to this repo
```shell
git clone git@github.com:<YOUR_GITHUB_ACCOUNT>/torch-scan.git
cd torch-scan
git remote add upstream https://github.com/frgfm/torch-scan.git
```

3 - You should not work on the `main` branch, so let's create a new one
```shell
git checkout -b a-short-description
```

4 - Now you only need to set up your development environment. First uninstall any existing installation of the library with `pip uninstall torchscan`, then:
```shell
pip install -e ".[dev]"
pre-commit install
```

### Developing your feature

#### Commits

- **Code**: make sure to provide docstrings for your Python code. In doing so, please follow the [Google style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) to ease the documentation process later.
- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/)
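As a sketch of the Google docstring style requested above (the function and its names here are purely illustrative, not part of the codebase):

```python
def format_num(num: float, precision: int = 2) -> str:
    """Format a number into a fixed-precision string.

    Args:
        num: the value to format
        precision: the number of decimal places to keep

    Returns:
        the formatted string
    """
    # Rely on Python's format-spec mini-language for the rounding
    return f"{num:.{precision}f}"
```

Docstrings written this way can be parsed by Sphinx's `napoleon` extension when the documentation is built.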

#### Unit tests

In order to run the same unit tests as the CI workflows, you can run unittests locally:

```shell
make test
```

#### Code quality

The CI will also run some sanity checks (header format, dependency consistency, etc.), which you can run as follows:

```shell
make quality
```

This will read `pyproject.toml` and run:
- lint checking, formatting ([ruff](https://docs.astral.sh/ruff/))
- type annotation checking ([mypy](https://github.com/python/mypy))

You can automatically fix most of these issues by running:

```shell
make style
```

### Submit your modifications

Push your latest changes to your remote branch
```shell
git push -u origin a-short-description
```

Then [open a Pull Request](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) from your fork's branch. Follow the instructions of the Pull Request template and then click on "Create a pull request".


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: Makefile
================================================
# this target runs checks on all files
quality:
	ruff format --check .
	ruff check .
	mypy

# this target runs checks on all files and potentially modifies some of them
style:
	ruff format .
	ruff check --fix .

# Run tests for the library
test:
	pytest --cov=torchscan tests/

# Build documentation for current version
single-docs:
	sphinx-build docs/source docs/_build -a

# Check that docs can build
full-docs:
	cd docs && bash build.sh


================================================
FILE: README.md
================================================
<p align="center">
  <img src="https://github.com/frgfm/torch-scan/releases/download/v0.1.1/logo_text.png" width="30%">
</p>

<p align="center">
  <a href="https://github.com/frgfm/torch-scan/actions/workflows/builds.yml">
    <img alt="CI Status" src="https://img.shields.io/github/actions/workflow/status/frgfm/torch-scan/builds.yml?branch=main&label=CI&logo=github&style=flat-square">
  </a>
  <a href="https://github.com/astral-sh/ruff">
    <img src="https://img.shields.io/badge/Linter-Ruff-FCC21B?style=flat-square&logo=ruff&logoColor=white" alt="ruff">
  </a>
  <a href="https://github.com/astral-sh/ruff">
    <img src="https://img.shields.io/badge/Formatter-Ruff-FCC21B?style=flat-square&logo=Python&logoColor=white" alt="ruff">
  </a>
  <a href="https://www.codacy.com/gh/frgfm/torch-scan/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=frgfm/torch-scan&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/9dc68e8bfce34d9dbc8b44a350e9adc7"/></a>
  <a href="https://codecov.io/gh/frgfm/torch-scan">
    <img src="https://img.shields.io/codecov/c/github/frgfm/torch-scan.svg?logo=codecov&style=flat-square&label=Coverage" alt="Test coverage percentage">
  </a>
</p>
<p align="center">
  <a href="https://pypi.org/project/torchscan/">
    <img src="https://img.shields.io/pypi/v/torchscan.svg?logo=PyPI&logoColor=fff&style=flat-square&label=PyPI" alt="PyPi Version">
  </a>
  <a href="https://anaconda.org/frgfm/torchscan">
    <img src="https://img.shields.io/conda/v/frgfm/torchscan.svg?logo=anaconda&label=Conda&logoColor=fff&style=flat-square" alt="Conda Version">
  </a>
  <img src="https://img.shields.io/pypi/pyversions/torchscan.svg?logo=Python&label=Python&logoColor=fff&style=flat-square" alt="pyversions">
  <a href="https://github.com/frgfm/torch-scan/blob/main/LICENSE">
    <img src="https://img.shields.io/github/license/frgfm/torch-scan.svg?label=License&logoColor=fff&style=flat-square" alt="License">
  </a>
</p>
<p align="center">
  <a href="https://frgfm.github.io/torch-scan">
    <img src="https://img.shields.io/github/actions/workflow/status/frgfm/torch-scan/docs.yml?branch=main&label=Documentation&logo=read-the-docs&logoColor=white&style=flat-square" alt="Documentation Status">
  </a>
</p>


The very useful [summary](https://www.tensorflow.org/api_docs/python/tf/keras/Model#summary) method of `tf.keras.Model`, ported to PyTorch and extended with more useful information.


## Quick Tour

### Inspecting your PyTorch architecture

Similar to the `torchsummary` implementation, `torchscan` presents useful module information in a readable format. For complex nested architectures, you can cap the display depth as follows:

```python
from torchvision.models import densenet121
from torchscan import summary

model = densenet121().eval().cuda()
summary(model, (3, 224, 224), max_depth=2)
```

which would yield

```shell
__________________________________________________________________________________________
Layer                        Type                  Output Shape              Param #
==========================================================================================
densenet                     DenseNet              (-1, 1000)                0
├─features                   Sequential            (-1, 1024, 7, 7)          0
|    └─conv0                 Conv2d                (-1, 64, 112, 112)        9,408
|    └─norm0                 BatchNorm2d           (-1, 64, 112, 112)        257
|    └─relu0                 ReLU                  (-1, 64, 112, 112)        0
|    └─pool0                 MaxPool2d             (-1, 64, 56, 56)          0
|    └─denseblock1           _DenseBlock           (-1, 256, 56, 56)         338,316
|    └─transition1           _Transition           (-1, 128, 28, 28)         33,793
|    └─denseblock2           _DenseBlock           (-1, 512, 28, 28)         930,072
|    └─transition2           _Transition           (-1, 256, 14, 14)         133,121
|    └─denseblock3           _DenseBlock           (-1, 1024, 14, 14)        2,873,904
|    └─transition3           _Transition           (-1, 512, 7, 7)           528,385
|    └─denseblock4           _DenseBlock           (-1, 1024, 7, 7)          2,186,272
|    └─norm5                 BatchNorm2d           (-1, 1024, 7, 7)          4,097
├─classifier                 Linear                (-1, 1000)                1,025,000
==========================================================================================
Trainable params: 7,978,856
Non-trainable params: 0
Total params: 7,978,856
------------------------------------------------------------------------------------------
Model size (params + buffers): 30.76 Mb
Framework & CUDA overhead: 423.57 Mb
Total RAM usage: 454.32 Mb
------------------------------------------------------------------------------------------
Floating Point Operations on forward: 5.74 GFLOPs
Multiply-Accumulations on forward: 2.87 GMACs
Direct memory accesses on forward: 2.90 GDMAs
__________________________________________________________________________________________
```

Results are aggregated to the selected depth for improved readability.
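As a quick sanity check on the reported model size, you can estimate the parameter storage by hand, assuming float32 weights (4 bytes per element). The helper below is an illustration, not part of `torchscan`:

```python
def model_size_mib(num_elements: int, bytes_per_element: int = 4) -> float:
    """Approximate storage for float32 parameters and buffers, in MiB."""
    return num_elements * bytes_per_element / 1024**2

# 7,978,856 parameters (from the summary above), stored as float32
print(round(model_size_mib(7_978_856), 2))  # 30.44
```

The small gap with the reported 30.76 Mb comes from the buffers (e.g. batch-norm running statistics), which are counted on top of the parameters.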

For reference, here are explanations of a few acronyms:

- **FLOPs**: floating-point operations (not to be confused with FLOPS, which is FLOPs per second)
- **MACs**: multiply-accumulate operations (cf. [wikipedia](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation))
- **DMAs**: direct memory accesses (many argue that it is more relevant than FLOPs or MACs to compare model inference speeds cf. [wikipedia](https://en.wikipedia.org/wiki/Direct_memory_access))
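To see where these numbers come from, the cost of a single convolution can be derived from its hyper-parameters: one MAC per kernel tap per output element, and two FLOPs (one multiply, one add) per MAC. Here is a minimal sketch of that arithmetic (ignoring bias, groups, and dilation; an illustration, not `torchscan`'s implementation):

```python
def conv2d_cost(h_in, w_in, c_in, c_out, k, stride=1, padding=0):
    """MACs and FLOPs of a bias-free Conv2d forward pass."""
    h_out = (h_in + 2 * padding - k) // stride + 1
    w_out = (w_in + 2 * padding - k) // stride + 1
    macs = h_out * w_out * c_out * (k * k * c_in)  # one MAC per kernel tap per output element
    flops = 2 * macs                               # each MAC = one multiply + one add
    return macs, flops

# DenseNet-121 conv0: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3, 224x224 input
macs, flops = conv2d_cost(224, 224, 3, 64, 7, stride=2, padding=3)
print(macs, flops)  # 118013952 236027904
```

That single layer thus accounts for roughly 0.24 of the 5.74 GFLOPs reported for the whole forward pass.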



Additionally, for highway nets (sequential models without branching or skip connections), `torchscan` supports receptive field estimation.

```python
from torchvision.models import vgg16
from torchscan import summary

model = vgg16().eval().cuda()
summary(model, (3, 224, 224), receptive_field=True, max_depth=0)
```

which will add each layer's receptive field (relative to the last convolutional layer) to the summary.
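The receptive field of a sequential stack follows a simple recurrence: each layer widens the field by `(kernel_size - 1)` times the cumulative stride of its inputs. A minimal sketch of that recurrence (an illustration under the usual assumptions, not `torchscan`'s code):

```python
def receptive_field(layers):
    """Receptive field of a sequential stack of (kernel_size, stride) layers."""
    rf, jump = 1, 1
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump  # each layer widens the field
        jump *= stride                  # cumulative stride seen by the next layer
    return rf

# VGG-16 feature extractor: five blocks of 3x3 convs, each closed by a 2x2 max-pool
vgg16_stack = (
    [(3, 1)] * 2 + [(2, 2)]
    + [(3, 1)] * 2 + [(2, 2)]
    + [(3, 1)] * 3 + [(2, 2)]
    + [(3, 1)] * 3 + [(2, 2)]
    + [(3, 1)] * 3 + [(2, 2)]
)
print(receptive_field(vgg16_stack))  # 212
```

The result matches the `vgg16` entry in the benchmark table of this README.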


## Setup

Python 3.8 (or newer) and [pip](https://pip.pypa.io/en/stable/)/[conda](https://docs.conda.io/en/latest/miniconda.html) are required to install Torchscan.

### Stable release

You can install the latest stable release of the package using [PyPI](https://pypi.org/project/torchscan/) as follows:

```shell
pip install torchscan
```

or using [conda](https://anaconda.org/frgfm/torchscan):

```shell
conda install -c frgfm torchscan
```

### Developer installation

Alternatively, if you wish to use the latest features of the project that haven't made their way to a release yet, you can install the package from source:

```shell
git clone https://github.com/frgfm/torch-scan.git
pip install -e torch-scan/.
```


## Benchmark

Below are the results for classification models supported by `torchvision`, for a single 3-channel image of size `224x224` (except `inception_v3`, which uses `299x299`).

| Model              | Params (M) | FLOPs (G) | MACs (G) | DMAs (G) | RF   |
| ------------------ | ---------- | --------- | -------- | -------- | ---- |
| alexnet            | 61.1       | 1.43      | 0.71     | 0.72     | 195  |
| googlenet          | 6.62       | 3.01      | 1.51     | 1.53     | --   |
| vgg11              | 132.86     | 15.23     | 7.61     | 7.64     | 150  |
| vgg11_bn           | 132.87     | 15.26     | 7.63     | 7.66     | 150  |
| vgg13              | 133.05     | 22.63     | 11.31    | 11.35    | 156  |
| vgg13_bn           | 133.05     | 22.68     | 11.33    | 11.37    | 156  |
| vgg16              | 138.36     | 30.96     | 15.47    | 15.52    | 212  |
| vgg16_bn           | 138.37     | 31.01     | 15.5     | 15.55    | 212  |
| vgg19              | 143.67     | 39.28     | 19.63    | 19.69    | 268  |
| vgg19_bn           | 143.68     | 39.34     | 19.66    | 19.72    | 268  |
| resnet18           | 11.69      | 3.64      | 1.82     | 1.84     | --   |
| resnet34           | 21.8       | 7.34      | 3.67     | 3.7      | --   |
| resnet50           | 25.56      | 8.21      | 4.11     | 4.15     | --   |
| resnet101          | 44.55      | 15.66     | 7.83     | 7.9      | --   |
| resnet152          | 60.19      | 23.1      | 11.56    | 11.65    | --   |
| inception_v3       | 27.16      | 11.45     | 5.73     | 5.76     | --   |
| squeezenet1_0      | 1.25       | 1.64      | 0.82     | 0.83     | --   |
| squeezenet1_1      | 1.24       | 0.7       | 0.35     | 0.36     | --   |
| wide_resnet50_2    | 68.88      | 22.84     | 11.43    | 11.51    | --   |
| wide_resnet101_2   | 126.89     | 45.58     | 22.8     | 22.95    | --   |
| densenet121        | 7.98       | 5.74      | 2.87     | 2.9      | --   |
| densenet161        | 28.68      | 15.59     | 7.79     | 7.86     | --   |
| densenet169        | 14.15      | 6.81      | 3.4      | 3.44     | --   |
| densenet201        | 20.01      | 8.7       | 4.34     | 4.39     | --   |
| resnext50_32x4d    | 25.03      | 8.51      | 4.26     | 4.3      | --   |
| resnext101_32x8d   | 88.79      | 32.93     | 16.48    | 16.61    | --   |
| mobilenet_v2       | 3.5        | 0.63      | 0.31     | 0.32     | --   |
| shufflenet_v2_x0_5 | 1.37       | 0.09      | 0.04     | 0.05     | --   |
| shufflenet_v2_x1_0 | 2.28       | 0.3       | 0.15     | 0.15     | --   |
| shufflenet_v2_x1_5 | 3.5        | 0.6       | 0.3      | 0.31     | --   |
| shufflenet_v2_x2_0 | 7.39       | 1.18      | 0.59     | 0.6      | --   |
| mnasnet0_5         | 2.22       | 0.22      | 0.11     | 0.12     | --   |
| mnasnet0_75        | 3.17       | 0.45      | 0.23     | 0.24     | --   |
| mnasnet1_0         | 4.38       | 0.65      | 0.33     | 0.34     | --   |
| mnasnet1_3         | 6.28       | 1.08      | 0.54     | 0.56     | --   |

The above results were produced using the `scripts/benchmark.py` script.

*Note: receptive field computation is currently only valid for highway nets.*



## What else

### Documentation

The full package documentation is available [here](https://frgfm.github.io/torch-scan/) for detailed specifications.


### Example script

An example script is provided for you to benchmark torchvision models using the library:

```shell
python scripts/benchmark.py
```


## Credits

This project is developed and maintained by the repo owner, and the implementation was inspired by or built upon the following contributions:

- [Pytorch summary](https://github.com/sksq96/pytorch-summary): existing PyTorch porting of `tf.keras.Model.summary`
- [Torchstat](https://github.com/Swall0w/torchstat): another module inspection tool
- [Flops counter Pytorch](https://github.com/sovrasov/flops-counter.pytorch): operation counter tool
- [THOP](https://github.com/Lyken17/pytorch-OpCounter): PyTorch Op counter
- Number of operations and memory estimation articles by [Matthijs Hollemans](https://machinethink.net/blog/how-fast-is-my-model/), and [Sicara](https://www.sicara.ai/blog/2019-28-10-deep-learning-memory-usage-and-pytorch-optimization-tricks)
- [Pruning Convolutional Neural Networks for Resource Efficient Inference](https://arxiv.org/abs/1611.06440)


## Citation

If you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference:

```bibtex
@misc{torchscan2020,
    title={Torchscan: meaningful module insights},
    author={François-Guillaume Fernandez},
    year={2020},
    month={March},
    publisher = {GitHub},
    howpublished = {\url{https://github.com/frgfm/torch-scan}}
}
```


## Contributing

Any sort of contribution is greatly appreciated!

You can find a short guide in [`CONTRIBUTING`](CONTRIBUTING.md) to help grow this project!



## License

Distributed under the Apache 2.0 License. See [`LICENSE`](LICENSE) for more information.


================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


================================================
FILE: docs/README.md
================================================
# Changing the documentation

The documentation of this project is built using `sphinx`. In order to install all the build dependencies, run the following command from the root folder of the repository:
```shell
pip install -e ".[docs]"
```

---
**NOTE**

You only generate the documentation to inspect it locally. Only the source files are pushed to the remote repository; the documentation is built automatically by the CI.

---

## Build the documentation

### Latest version

In most cases, you will only be changing the documentation of the latest version (dev version). In this case, you can build the documentation (the HTML files) with the following command:

```shell
sphinx-build docs/source docs/_build -a
```

Then open `docs/_build/index.html` in your web browser to navigate in it.


### Multi-version documentation

In rare cases, you might want to modify the documentation for other versions. You will then have to build the documentation for the multiple versions of the package, which you can do by running this command from the `docs` folder:
```shell
bash build.sh
```


================================================
FILE: docs/build.sh
================================================
function deploy_doc(){
    if [ ! -z "$1" ]
    then
        git checkout $1
    fi
    COMMIT=$(git rev-parse --short HEAD)
    echo "Creating doc at commit" $COMMIT "and pushing to folder $2"
    pip install -U ..
    if [ ! -z "$2" ]
    then
        if [ "$2" == "latest" ]; then
            echo "Pushing main"
            sphinx-build source build/$2 -a
        elif [ -d build/$2 ]; then
            echo "Directory" $2 "already exists"
        else
            echo "Pushing version" $2
            cp -r _static source/ && cp _conf.py source/conf.py
            sphinx-build source build/$2 -a
        fi
    else
        echo "Pushing stable"
        cp -r _static source/ && cp _conf.py source/conf.py
        sphinx-build source build -a
    fi
    git checkout source/ && git clean -f source/
}

# exit when any command fails
set -e
# You can find the commit for each tag on https://github.com/frgfm/torch-scan/tags
if [ -d build ]; then rm -Rf build; fi
mkdir build
cp -r source/_static .
cp source/conf.py _conf.py
git fetch --all --tags --unshallow
deploy_doc "" latest
deploy_doc "7ac9c839" v0.1.0
deploy_doc "900eb166" v0.1.1
deploy_doc "29fa4ed1" # v0.1.2 Latest stable release
rm -rf _build _static


================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd


================================================
FILE: docs/source/_static/css/custom.css
================================================
h1 {
    font-size: 200%;
}

/* Github button */

.github-repo {
    display: flex;
    justify-content: center;
}

/* Version control */

.version-button {
    color: gray;
    border: none;
    padding: 5px;
    font-size: 15px;
    cursor: pointer;
}

.version-button:hover, .version-button:focus {
    color: white;
    background-color: gray;
}

.version-dropdown {
    display: none;
    min-width: 160px;
    overflow: auto;
    font-size: 15px;
}

.version-dropdown a {
    color: gray;
    padding: 3px 4px;
    text-decoration: none;
    display: block;
}

.version-dropdown a:hover {
    color: white;
    background-color: gray;
}

.version-show {
    display: block;
}


================================================
FILE: docs/source/_static/js/custom.js
================================================
// Based on https://github.com/huggingface/transformers/blob/master/docs/source/_static/js/custom.js


// These two things need to be updated at each release for the version selector.
// Last stable version
const stableVersion = "v0.1.2"
// Dictionary doc folder to label. The last stable version should have an empty key.
const versionMapping = {
    "latest": "latest",
    "": "v0.1.2 (stable)",
    "v0.1.1": "v0.1.1",
    "v0.1.0": "v0.1.0",
}

function addGithubButton() {
    const div = `
        <div class="github-repo">
            <a
                class="github-button"
                href="https://github.com/frgfm/torch-scan"
                data-size="large"
                data-show-count="true"
                aria-label="Star frgfm/torch-scan on GitHub">Star</a>
        </div>
    `;
    document.querySelector(".sidebar-brand").insertAdjacentHTML('afterend', div);
}

function addVersionControl() {
    // To grab the version currently in view, we parse the url
    const parts = location.toString().split('/');
    let versionIndex = parts.length - 2;
    // Index page may not have a last part with filename.html so we need to go up
    if (parts[parts.length - 1] != "" && ! parts[parts.length - 1].match(/\.html$|^search.html?/)) {
        versionIndex = parts.length - 1;
    }
    const version = parts[versionIndex];

    // Menu with all the links,
    const versionMenu = document.createElement("div");

    const htmlLines = [];
    for (const [key, value] of Object.entries(versionMapping)) {
        let baseUrlIndex = (version == "torch-scan") ? versionIndex + 1: versionIndex;
        var urlParts = parts.slice(0, baseUrlIndex);
        if (key != "") {
            urlParts = urlParts.concat([key]);
        }
        urlParts = urlParts.concat(parts.slice(versionIndex+1));
        htmlLines.push(`<a href="${urlParts.join('/')}">${value}</a>`);
    }

    versionMenu.classList.add("version-dropdown");
    versionMenu.innerHTML = htmlLines.join('\n');

    // Button for version selection
    const versionButton = document.createElement("div");
    versionButton.classList.add("version-button");
    let label = (version == "torch-scan") ? stableVersion : version
    versionButton.innerText = label.concat(" ▼");

    // Toggle the menu when we click on the button
    versionButton.addEventListener("click", () => {
        versionMenu.classList.toggle("version-show");
    });

    // Hide the menu when we click elsewhere
    window.addEventListener("click", (event) => {
        if (event.target != versionButton){
            versionMenu.classList.remove('version-show');
        }
    });

    // Container
    const div = document.createElement("div");
    div.appendChild(versionButton);
    div.appendChild(versionMenu);
    div.style.paddingTop = '5px';
    div.style.paddingBottom = '5px';
    div.style.display = 'block';
    div.style.textAlign = 'center';

    const scrollDiv = document.querySelector(".sidebar-brand");
    scrollDiv.insertBefore(div, scrollDiv.children[1]);
}

/*!
 * github-buttons v2.2.10
 * (c) 2019 なつき
 * @license BSD-2-Clause
 */
/**
 * modified to run programmatically
 */
function parseGithubButtons (){"use strict";var e=window.document,t=e.location,o=window.encodeURIComponent,r=window.decodeURIComponent,n=window.Math,a=window.HTMLElement,i=window.XMLHttpRequest,l="https://unpkg.com/github-buttons@2.2.10/dist/buttons.html",c=i&&i.prototype&&"withCredentials"in i.prototype,d=c&&a&&a.prototype.attachShadow&&!a.prototype.attachShadow.prototype,s=function(e,t,o){e.addEventListener?e.addEventListener(t,o):e.attachEvent("on"+t,o)},u=function(e,t,o){e.removeEventListener?e.removeEventListener(t,o):e.detachEvent("on"+t,o)},h=function(e,t,o){var r=function(n){return u(e,t,r),o(n)};s(e,t,r)},f=function(e,t,o){var r=function(n){if(t.test(e.readyState))return u(e,"readystatechange",r),o(n)};s(e,"readystatechange",r)},p=function(e){return function(t,o,r){var n=e.createElement(t);if(o)for(var a in o){var i=o[a];null!=i&&(null!=n[a]?n[a]=i:n.setAttribute(a,i))}if(r)for(var l=0,c=r.length;l<c;l++){var d=r[l];n.appendChild("string"==typeof d?e.createTextNode(d):d)}return n}},g=p(e),b=function(e){var t;return function(){t||(t=1,e.apply(this,arguments))}},m="body{margin:0}a{color:#24292e;text-decoration:none;outline:0}.octicon{display:inline-block;vertical-align:text-top;fill:currentColor}.widget{ display:inline-block;overflow:hidden;font-family:-apple-system, BlinkMacSystemFont, \"Segoe UI\", Helvetica, Arial, sans-serif;font-size:0;white-space:nowrap;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.btn,.social-count{display:inline-block;height:14px;padding:2px 5px;font-size:11px;font-weight:600;line-height:14px;vertical-align:bottom;cursor:pointer;border:1px solid #c5c9cc;border-radius:0.25em}.btn{background-color:#eff3f6;background-image:-webkit-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:-moz-linear-gradient(top, #fafbfc, #eff3f6 90%);background-image:linear-gradient(180deg, #fafbfc, #eff3f6 90%);background-position:-1px -1px;background-repeat:repeat-x;background-size:110% 
110%;border-color:rgba(27,31,35,0.2);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFFAFBFC', endColorstr='#FFEEF2F5')}.btn:active{background-color:#e9ecef;background-image:none;border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);box-shadow:inset 0 0.15em 0.3em rgba(27,31,35,0.15)}.btn:focus,.btn:hover{background-color:#e6ebf1;background-image:-webkit-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:-moz-linear-gradient(top, #f0f3f6, #e6ebf1 90%);background-image:linear-gradient(180deg, #f0f3f6, #e6ebf1 90%);border-color:#a5a9ac;border-color:rgba(27,31,35,0.35);-ms-filter:\"progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')\";*filter:progid:DXImageTransform.Microsoft.Gradient(startColorstr='#FFF0F3F6', endColorstr='#FFE5EAF0')}.social-count{position:relative;margin-left:5px;background-color:#fff}.social-count:focus,.social-count:hover{color:#0366d6}.social-count b,.social-count i{position:absolute;top:50%;left:0;display:block;width:0;height:0;margin:-4px 0 0 -4px;border:solid transparent;border-width:4px 4px 4px 0;_line-height:0;_border-top-color:red !important;_border-bottom-color:red !important;_border-left-color:red !important;_filter:chroma(color=red)}.social-count b{border-right-color:#c5c9cc}.social-count i{margin-left:-3px;border-right-color:#fff}.lg .btn,.lg .social-count{height:16px;padding:5px 10px;font-size:12px;line-height:16px}.lg .social-count{margin-left:6px}.lg .social-count b,.lg .social-count i{margin:-5px 0 0 -5px;border-width:5px 5px 5px 0}.lg .social-count i{margin-left:-4px}\n",v={"mark-github":{width:16,height:16,path:'<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 
2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"/>'},eye:{width:16,height:16,path:'<path fill-rule="evenodd" d="M8.06 2C3 2 0 8 0 8s3 6 8.06 6C13 14 16 8 16 8s-3-6-7.94-6zM8 12c-2.2 0-4-1.78-4-4 0-2.2 1.8-4 4-4 2.22 0 4 1.8 4 4 0 2.22-1.78 4-4 4zm2-4c0 1.11-.89 2-2 2-1.11 0-2-.89-2-2 0-1.11.89-2 2-2 1.11 0 2 .89 2 2z"/>'},star:{width:14,height:16,path:'<path fill-rule="evenodd" d="M14 6l-4.9-.64L7 1 4.9 5.36 0 6l3.6 3.26L2.67 14 7 11.67 11.33 14l-.93-4.74L14 6z"/>'},"repo-forked":{width:10,height:16,path:'<path fill-rule="evenodd" d="M8 1a1.993 1.993 0 0 0-1 3.72V6L5 8 3 6V4.72A1.993 1.993 0 0 0 2 1a1.993 1.993 0 0 0-1 3.72V6.5l3 3v1.78A1.993 1.993 0 0 0 5 15a1.993 1.993 0 0 0 1-3.72V9.5l3-3V4.72A1.993 1.993 0 0 0 8 1zM2 4.2C1.34 4.2.8 3.65.8 3c0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3 10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2zm3-10c-.66 0-1.2-.55-1.2-1.2 0-.65.55-1.2 1.2-1.2.65 0 1.2.55 1.2 1.2 0 .65-.55 1.2-1.2 1.2z"/>'},"issue-opened":{width:14,height:16,path:'<path fill-rule="evenodd" d="M7 2.3c3.14 0 5.7 2.56 5.7 5.7s-2.56 5.7-5.7 5.7A5.71 5.71 0 0 1 1.3 8c0-3.14 2.56-5.7 5.7-5.7zM7 1C3.14 1 0 4.14 0 8s3.14 7 7 7 7-3.14 7-7-3.14-7-7-7zm1 3H6v5h2V4zm0 6H6v2h2v-2z"/>'},"cloud-download":{width:16,height:16,path:'<path fill-rule="evenodd" d="M9 12h2l-3 3-3-3h2V7h2v5zm3-8c0-.44-.91-3-4.5-3C5.08 1 3 2.92 3 5 1.02 5 0 6.52 0 8c0 1.53 1 3 3 3h3V9.7H3C1.38 9.7 1.3 8.28 1.3 8c0-.17.05-1.7 1.7-1.7h1.3V5c0-1.39 1.56-2.7 3.2-2.7 2.55 0 3.13 1.55 3.2 1.8v1.2H12c.81 0 2.7.22 2.7 2.2 0 2.09-2.25 2.2-2.7 2.2h-2V11h2c2.08 0 4-1.16 4-3.5C16 5.06 14.08 4 12 4z"/>'}},w={},x=function(e,t,o){var 
r=p(e.ownerDocument),n=e.appendChild(r("style",{type:"text/css"}));n.styleSheet?n.styleSheet.cssText=m:n.appendChild(e.ownerDocument.createTextNode(m));var a,l,d=r("a",{className:"btn",href:t.href,target:"_blank",innerHTML:(a=t["data-icon"],l=/^large$/i.test(t["data-size"])?16:14,a=(""+a).toLowerCase().replace(/^octicon-/,""),{}.hasOwnProperty.call(v,a)||(a="mark-github"),'<svg version="1.1" width="'+l*v[a].width/v[a].height+'" height="'+l+'" viewBox="0 0 '+v[a].width+" "+v[a].height+'" class="octicon octicon-'+a+'" aria-hidden="true">'+v[a].path+"</svg>"),"aria-label":t["aria-label"]||void 0},[" ",r("span",{},[t["data-text"]||""])]);/\.github\.com$/.test("."+d.hostname)?/^https?:\/\/((gist\.)?github\.com\/[^\/?#]+\/[^\/?#]+\/archive\/|github\.com\/[^\/?#]+\/[^\/?#]+\/releases\/download\/|codeload\.github\.com\/)/.test(d.href)&&(d.target="_top"):(d.href="#",d.target="_self");var u,h,g,x,y=e.appendChild(r("div",{className:"widget"+(/^large$/i.test(t["data-size"])?" lg":"")},[d]));/^(true|1)$/i.test(t["data-show-count"])&&"github.com"===d.hostname&&(u=d.pathname.replace(/^(?!\/)/,"/").match(/^\/([^\/?#]+)(?:\/([^\/?#]+)(?:\/(?:(subscription)|(fork)|(issues)|([^\/?#]+)))?)?(?:[\/?#]|$)/))&&!u[6]?(u[2]?(h="/repos/"+u[1]+"/"+u[2],u[3]?(x="subscribers_count",g="watchers"):u[4]?(x="forks_count",g="network"):u[5]?(x="open_issues_count",g="issues"):(x="stargazers_count",g="stargazers")):(h="/users/"+u[1],g=x="followers"),function(e,t){var o=w[e]||(w[e]=[]);if(!(o.push(t)>1)){var r=b(function(){for(delete w[e];t=o.shift();)t.apply(null,arguments)});if(c){var n=new i;s(n,"abort",r),s(n,"error",r),s(n,"load",function(){var e;try{e=JSON.parse(n.responseText)}catch(e){return void r(e)}r(200!==n.status,e)}),n.open("GET",e),n.send()}else{var a=this||window;a._=function(e){a._=null,r(200!==e.meta.status,e.data)};var 
l=p(a.document)("script",{async:!0,src:e+(/\?/.test(e)?"&":"?")+"callback=_"}),d=function(){a._&&a._({meta:{}})};s(l,"load",d),s(l,"error",d),l.readyState&&f(l,/de|m/,d),a.document.getElementsByTagName("head")[0].appendChild(l)}}}.call(this,"https://api.github.com"+h,function(e,t){if(!e){var n=t[x];y.appendChild(r("a",{className:"social-count",href:t.html_url+"/"+g,target:"_blank","aria-label":n+" "+x.replace(/_count$/,"").replace("_"," ").slice(0,n<2?-1:void 0)+" on GitHub"},[r("b"),r("i"),r("span",{},[(""+n).replace(/\B(?=(\d{3})+(?!\d))/g,",")])]))}o&&o(y)})):o&&o(y)},y=window.devicePixelRatio||1,C=function(e){return(y>1?n.ceil(n.round(e*y)/y*2)/2:n.ceil(e))||0},F=function(e,t){e.style.width=t[0]+"px",e.style.height=t[1]+"px"},k=function(t,r){if(null!=t&&null!=r)if(t.getAttribute&&(t=function(e){for(var t={href:e.href,title:e.title,"aria-label":e.getAttribute("aria-label")},o=["icon","text","size","show-count"],r=0,n=o.length;r<n;r++){var a="data-"+o[r];t[a]=e.getAttribute(a)}return null==t["data-text"]&&(t["data-text"]=e.textContent||e.innerText),t}(t)),d){var a=g("span",{title:t.title||void 0});x(a.attachShadow({mode:"closed"}),t,function(){r(a)})}else{var i=g("iframe",{src:"javascript:0",title:t.title||void 0,allowtransparency:!0,scrolling:"no",frameBorder:0});F(i,[0,0]),i.style.border="none";var c=function(){var a,d=i.contentWindow;try{a=d.document.body}catch(t){return void e.body.appendChild(i.parentNode.removeChild(i))}u(i,"load",c),x.call(d,a,t,function(e){var a=function(e){var t=e.offsetWidth,o=e.offsetHeight;if(e.getBoundingClientRect){var r=e.getBoundingClientRect();t=n.max(t,C(r.width)),o=n.max(o,C(r.height))}return[t,o]}(e);i.parentNode.removeChild(i),h(i,"load",function(){F(i,a)}),i.src=l+"#"+(i.name=function(e){var t=[];for(var r in e){var n=e[r];null!=n&&t.push(o(r)+"="+o(n))}return t.join("&")}(t)),r(i)})};s(i,"load",c),e.body.appendChild(i)}};t.protocol+"//"+t.host+t.pathname===l?x(e.body,function(e){for(var 
t={},o=e.split("&"),n=0,a=o.length;n<a;n++){var i=o[n];if(""!==i){var l=i.split("=");t[r(l[0])]=null!=l[1]?r(l.slice(1).join("=")):void 0}}return t}(window.name||t.hash.replace(/^#/,""))):function(t){if(/m/.test(e.readyState)||!/g/.test(e.readyState)&&!e.documentElement.doScroll)setTimeout(t);else if(e.addEventListener){var o=b(t);h(e,"DOMContentLoaded",o),h(window,"load",o)}else f(e,/m/,t)}(function(){for(var t=e.querySelectorAll?e.querySelectorAll("a.github-button"):function(){for(var t=[],o=e.getElementsByTagName("a"),r=0,n=o.length;r<n;r++)~(" "+o[r].className+" ").replace(/[ \t\n\f\r]+/g," ").indexOf(" github-button ")&&t.push(o[r]);return t}(),o=0,r=t.length;o<r;o++)!function(e){k(e,function(t){e.parentNode.replaceChild(t,e)})}(t[o])})};

function onLoad() {
    addVersionControl();
    addGithubButton();
    parseGithubButtons();
}

window.addEventListener("load", onLoad);


================================================
FILE: docs/source/changelog.rst
================================================
Changelog
=========


v0.1.2 (2022-08-03)
-------------------
Release note: `v0.1.2 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.2>`_

v0.1.1 (2020-08-04)
-------------------
Release note: `v0.1.1 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.1>`_

v0.1.0 (2020-05-21)
-------------------
Release note: `v0.1.0 <https://github.com/frgfm/torch-scan/releases/tag/v0.1.0>`_


================================================
FILE: docs/source/conf.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import sys
from datetime import datetime
from pathlib import Path

sys.path.insert(0, str(Path.cwd().parent.parent))
import torchscan

# -- Project information -----------------------------------------------------

master_doc = "index"
project = "torchscan"
copyright = f"2020-{datetime.now().year}, François-Guillaume Fernandez"
author = "François-Guillaume Fernandez"

# The full version, including alpha/beta/rc tags
version = torchscan.__version__
release = torchscan.__version__ + "-git"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.mathjax",
    "sphinxemoji.sphinxemoji",  # cf. https://sphinxemojicodes.readthedocs.io/en/stable/
    "sphinx_copybutton",
]

napoleon_use_ivar = True

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "friendly"
pygments_dark_style = "monokai"
highlight_language = "python3"

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "furo"

html_title = "Torchscan"
html_logo = "_static/images/logo.png"
html_favicon = "_static/images/favicon.ico"
language = "en"

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
    "footer_icons": [
        {
            "name": "GitHub",
            "url": "https://github.com/frgfm/torch-scan",
            "html": """
                <svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
                    <path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
                </svg>
            """,
            "class": "",
        },
    ],
    "source_repository": "https://github.com/frgfm/torch-scan/",
    "source_branch": "main",
    "source_directory": "docs/source/",
    "sidebar_hide_name": True,
}


# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]


# Add googleanalytics id
# ref: https://github.com/orenhecht/googleanalytics/blob/master/sphinxcontrib/googleanalytics.py
def add_ga_javascript(app, pagename, templatename, context, doctree):
    metatags = context.get("metatags", "")
    metatags += """
    <!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id={0}"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){{dataLayer.push(arguments);}}
  gtag('js', new Date());
  gtag('config', '{0}');
</script>
    """.format(app.config.googleanalytics_id)
    context["metatags"] = metatags


def setup(app):
    app.add_config_value("googleanalytics_id", "UA-148140560-3", "html")
    app.add_css_file("css/custom.css")
    app.add_js_file("js/custom.js")
    app.connect("html-page-context", add_ga_javascript)


================================================
FILE: docs/source/index.rst
================================================
**************************************
TorchScan: inspect your PyTorch models
**************************************

The :mod:`torchscan` package provides tools for analyzing your PyTorch modules and models. In addition to performance benchmarks, a comprehensive architecture comparison requires some insight into the model's complexity and its usage of computational and memory resources.


This project is meant for:

* |:zap:| **exploration**: easily assess the influence of your architecture on resource consumption
* |:woman_scientist:| **research**: quickly implement your own ideas to mitigate latency


.. toctree::
   :maxdepth: 2
   :caption: Getting Started
   :hidden:

   installing


.. toctree::
   :maxdepth: 1
   :caption: Package Reference
   :hidden:

   torchscan
   modules
   process
   utils

.. toctree::
   :maxdepth: 2
   :caption: Notes
   :hidden:

   changelog


Supported layers
^^^^^^^^^^^^^^^^

Here is the list of supported layers for FLOPs, MACs, DMAs and receptive field computation:

Non-linear activations
""""""""""""""""""""""

* `torch.nn.ReLU <https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_
* `torch.nn.ELU <https://pytorch.org/docs/stable/generated/torch.nn.ELU.html>`_
* `torch.nn.LeakyReLU <https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html>`_
* `torch.nn.ReLU6 <https://pytorch.org/docs/stable/generated/torch.nn.ReLU6.html>`_
* `torch.nn.Tanh <https://pytorch.org/docs/stable/generated/torch.nn.Tanh.html>`_
* `torch.nn.Sigmoid <https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html>`_

Linear layers
"""""""""""""

* `torch.nn.Identity <https://pytorch.org/docs/stable/generated/torch.nn.Identity.html>`_
* `torch.nn.Linear <https://pytorch.org/docs/stable/generated/torch.nn.Linear.html>`_

Convolutions
""""""""""""

* `torch.nn.Conv1d <https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html>`_
* `torch.nn.Conv2d <https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html>`_
* `torch.nn.Conv3d <https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html>`_
* `torch.nn.ConvTranspose1d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html>`_
* `torch.nn.ConvTranspose2d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html>`_
* `torch.nn.ConvTranspose3d <https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html>`_

Pooling
"""""""

* `torch.nn.MaxPool1d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html>`_
* `torch.nn.MaxPool2d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html>`_
* `torch.nn.MaxPool3d <https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html>`_
* `torch.nn.AvgPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool1d.html>`_
* `torch.nn.AvgPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html>`_
* `torch.nn.AvgPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AvgPool3d.html>`_
* `torch.nn.AdaptiveMaxPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool1d.html>`_
* `torch.nn.AdaptiveMaxPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool2d.html>`_
* `torch.nn.AdaptiveMaxPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveMaxPool3d.html>`_
* `torch.nn.AdaptiveAvgPool1d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool1d.html>`_
* `torch.nn.AdaptiveAvgPool2d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html>`_
* `torch.nn.AdaptiveAvgPool3d <https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool3d.html>`_

Normalization
"""""""""""""

* `torch.nn.BatchNorm1d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`_
* `torch.nn.BatchNorm2d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_
* `torch.nn.BatchNorm3d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm3d.html>`_

Other
"""""

* `torch.nn.Flatten <https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html>`_
* `torch.nn.Dropout <https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_


*Please note that the functional API of PyTorch is not supported.*


================================================
FILE: docs/source/installing.rst
================================================

************
Installation
************

This library requires `Python <https://www.python.org/downloads/>`_ 3.8 or higher.

Via Python Package
==================

Install the latest stable release of the package using `pip <https://pip.pypa.io/en/stable/installation/>`_:

.. code:: bash

    pip install torchscan


Via Conda
=========

Install the latest stable release of the package using `conda <https://docs.conda.io/en/latest/>`_:

.. code:: bash

    conda install -c frgfm torchscan


Via Git
=======

Install the library in developer mode:

.. code:: bash

    git clone https://github.com/frgfm/torch-scan.git
    pip install -e torch-scan/.


================================================
FILE: docs/source/modules.rst
================================================
torchscan.modules
=================

The modules subpackage contains tools for inspection of modules.

.. currentmodule:: torchscan.modules


FLOPs
-----
Related to the number of floating point operations performed during model inference.

.. autofunction:: module_flops
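
As an illustration, the FLOP count of a fully-connected layer can be derived by hand. The sketch below matches the values exercised in the package's test suite; the helper name is ours, not part of the package:

```python
def linear_flops(in_features: int, out_features: int, bias: bool = True) -> int:
    """FLOPs of one forward pass through a fully-connected layer.

    Each output unit costs `in_features` multiplications and
    `in_features - 1` additions, plus one addition for the bias.
    """
    flops = out_features * (2 * in_features - 1)
    if bias:
        flops += out_features
    return flops


print(linear_flops(8, 4))  # 64, as for nn.Linear(8, 4)
```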


MACs
-----
Related to the number of multiply-accumulate operations performed during model inference.

.. autofunction:: module_macs
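
For reference, the MAC counts of fully-connected and 2D convolution layers reduce to simple products. This is an illustrative sketch consistent with the package's test values; the helper names are ours:

```python
def linear_macs(in_features: int, out_features: int) -> int:
    # One multiply-accumulate per (input, output) pair
    return in_features * out_features


def conv2d_macs(c_in: int, c_out: int, kernel: int, h_out: int, w_out: int) -> int:
    # One MAC per kernel element, per input channel, per output position
    return kernel * kernel * c_in * c_out * h_out * w_out


print(linear_macs(8, 4))             # 32
print(conv2d_macs(3, 8, 3, 30, 30))  # 194400, as for nn.Conv2d(3, 8, 3) on a 32x32 input
```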


DMAs
----
Related to the number of direct memory accesses during model inference.

.. autofunction:: module_dmas
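
As a back-of-the-envelope model, the DMA count of a fully-connected layer adds up its parameter reads, input reads and output writes. A sketch consistent with the package's test values; the helper name is ours:

```python
def linear_dmas(in_features: int, out_features: int) -> int:
    # Reads of the weight matrix and bias vector
    weight_and_bias_reads = out_features * (in_features + 1)
    # Read of the input vector, write of the output vector
    input_reads = in_features
    output_writes = out_features
    return weight_and_bias_reads + input_reads + output_writes


print(linear_dmas(8, 4))  # 48, as for nn.Linear(8, 4)
```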


Receptive field
---------------
Related to the effective receptive field of a layer.

.. autofunction:: module_rf
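
The receptive field of a stack of layers follows the usual composition rule: each layer grows the field by `(kernel - 1)` times the cumulative stride (jump) of the layers before it. A minimal sketch, assuming unit dilation (dilation can be folded in as an effective kernel size of `d * (k - 1) + 1`):

```python
def compose_rf(layers):
    """Receptive field of a stack of (kernel_size, stride) layers."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf


print(compose_rf([(3, 1)]))          # 3: a single 3x3 convolution
print(compose_rf([(3, 1), (2, 2)]))  # 4: the same convolution followed by a 2x2 pooling
```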


================================================
FILE: docs/source/process.rst
================================================
torchscan.process
=================

The process subpackage contains tools regarding active Python processes.

The following functions are available:

.. automodule:: torchscan.process
.. currentmodule:: torchscan.process


.. autofunction:: get_process_gpu_ram
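
Per-process GPU RAM is the kind of figure reported by ``nvidia-smi --query-compute-apps=pid,used_memory --format=csv``. Below is a hypothetical parser for that CSV output — an illustration of the lookup, not the library's actual implementation:

```python
def gpu_ram_from_csv(csv_output: str, pid: int) -> float:
    """Sum the MiB reported for `pid` in nvidia-smi compute-apps CSV output."""
    total = 0.0
    for line in csv_output.strip().splitlines()[1:]:  # skip the header row
        raw_pid, raw_mem = (field.strip() for field in line.split(","))
        if int(raw_pid) == pid:
            total += float(raw_mem.split()[0])  # "1024 MiB" -> 1024.0
    return total


sample = "pid, used_gpu_memory [MiB]\n1337, 1024 MiB\n42, 512 MiB"
print(gpu_ram_from_csv(sample, 1337))  # 1024.0
```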


================================================
FILE: docs/source/torchscan.rst
================================================
torchscan
=========


.. currentmodule:: torchscan


Crawler
~~~~~~~

.. autofunction:: crawl_module
.. autofunction:: summary


================================================
FILE: docs/source/utils.rst
================================================
torchscan.utils
===============

.. currentmodule:: torchscan.utils

.. autofunction:: format_info

.. autofunction:: aggregate_info


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "torchscan"
description = "Useful information about your PyTorch module"
authors = [
    {name = "François-Guillaume Fernandez", email = "fg-feedback@protonmail.com"}
]
readme = "README.md"
requires-python = ">=3.8,<4"
license = {file = "LICENSE"}
keywords = ["pytorch", "deep learning", "summary", "memory", "ram"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: Apache Software License",
    "Natural Language :: English",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Mathematics",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dynamic = ["version"]
dependencies = [
    "torch>=2.0.0,<3.0.0",
]

[project.optional-dependencies]
test = [
    "pytest>=7.3.2",
    "pytest-cov>=3.0.0,<5.0.0",
    "pytest-pretty>=1.0.0,<2.0.0",
]
quality = [
    "ruff==0.6.4",
    "mypy==1.14.0",
    "pre-commit>=3.0.0,<4.0.0",
]
docs = [
    "sphinx>=3.0.0,!=3.5.0",
    "furo>=2022.3.4",
    "sphinxemoji>=0.1.8",
    "sphinx-copybutton>=0.3.1",
    # Indirect deps
    # cf. https://github.com/readthedocs/readthedocs.org/issues/9038
    "Jinja2<3.1",
]
dev = [
    # test
    "pytest>=7.3.2",
    "pytest-cov>=3.0.0,<5.0.0",
    "pytest-pretty>=1.0.0,<2.0.0",
    # style
    "ruff==0.6.4",
    "mypy==1.14.0",
    "pre-commit>=3.0.0,<4.0.0",
    # docs
    "sphinx>=3.0.0,!=3.5.0",
    "furo>=2022.3.4",
    "sphinxemoji>=0.1.8",
    "sphinx-copybutton>=0.3.1",
    "Jinja2<3.1",
]

[project.urls]
documentation = "https://frgfm.github.io/torch-scan"
repository = "https://github.com/frgfm/torch-scan"
tracker = "https://github.com/frgfm/torch-scan/issues"
changelog = "https://frgfm.github.io/torch-scan/latest/changelog.html"

[tool.setuptools]
zip-safe = true

[tool.setuptools.packages.find]
exclude = ["docs*", "scripts*", "tests*"]

[tool.pytest.ini_options]
testpaths = ["torchscan/"]

[tool.coverage.run]
source = ["torchscan/"]

[tool.ruff]
line-length = 120
target-version = "py311"
preview = true

[tool.ruff.lint]
select = [
    "F",  # pyflakes
    "E",  # pycodestyle errors
    "W",  # pycodestyle warnings
    "I",  # isort
    "N",  # pep8-naming
    "D101", "D103",  # pydocstyle missing docstring in public function/class
    "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419",  # pydocstyle
    "YTT",  # flake8-2020
    "ANN",  # flake8-annotations
    "ASYNC",  # flake8-async
    "S",  # flake8-bandit
    "BLE",  # flake8-blind-except
    "B",  # flake8-bugbear
    "A",  # flake8-builtins
    "COM",  # flake8-commas
    "CPY",  # flake8-copyright
    "C4",  # flake8-comprehensions
    "T10",  # flake8-debugger
    "ISC",  # flake8-implicit-str-concat
    "ICN",  # flake8-import-conventions
    "LOG",  # flake8-logging
    "PIE",  # flake8-pie
    "T20",  # flake8-print
    "PYI",  # flake8-pyi
    "PT",  # flake8-pytest-style
    "Q",    # flake8-quotes
    "RET",  # flake8-return
    "SLF",  # flake8-self
    "SIM",  # flake8-simplify
    "ARG",  # flake8-unused-arguments
    "PTH",  # flake8-use-pathlib
    "PERF",  # perflint
    "NPY",  # numpy
    "FAST",  # fastapi
    "FURB",  # refurb
    "RUF",  # ruff specific
    "N",  # pep8-naming
]
ignore = [
    "E501",  # line too long, handled by the formatter
    "B008",  # do not perform function calls in argument defaults
    "B904",  # raise from
    "C901",  # too complex
    "F403",  # star imports
    "E731",  # lambda assignment
    "C416",  # list comprehension to list()
    "ANN101",  # missing type annotations on self
    "ANN102",  # missing type annotations on cls
    "ANN002",  # missing type annotations on *args
    "ANN003",  # missing type annotations on **kwargs
    "COM812",  # trailing comma missing
    "N812",  # lowercase imported as non-lowercase
    "ISC001",  # implicit string concatenation (handled by format)
    "ANN401",  # Dynamically typed expressions (typing.Any) are disallowed
    "SLF001",  # Private member accessed
]
exclude = [".git"]

[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"

[tool.ruff.lint.isort]
known-first-party = ["torchscan", "app"]
known-third-party = ["torch", "torchvision"]

[tool.ruff.lint.per-file-ignores]
"**/__init__.py" = ["I001", "F401", "CPY001"]
"scripts/**.py" = ["D", "T201", "N812", "S101", "ANN"]
".github/**.py" = ["D", "T201", "S602", "S101", "ANN"]
"docs/**.py" = ["E402", "D103", "ANN", "A001", "ARG001"]
"tests/**.py" = ["D101", "D103", "CPY001", "S101", "PT011", "ANN", "SLF001"]
"demo/**.py" = ["D103", "ANN"]
"setup.py" = ["T201"]
"torchscan/process/memory.py" = ["S60"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"


[tool.mypy]
python_version = "3.11"
files = "torchscan/"
show_error_codes = true
pretty = true
warn_unused_ignores = true
warn_redundant_casts = true
no_implicit_optional = true
disallow_untyped_calls = true
check_untyped_defs = true
implicit_reexport = false
disallow_untyped_defs = true
explicit_package_bases = true


================================================
FILE: scripts/benchmark.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

"""
Torchvision benchmark
"""

import torch
from torchvision import models

from torchscan import crawl_module

TORCHVISION_MODELS = [
    "alexnet",
    "googlenet",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "inception_v3",
    "squeezenet1_0",
    "squeezenet1_1",
    "wide_resnet50_2",
    "wide_resnet101_2",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "mobilenet_v2",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    margin = 4
    headers = ["Model", "Params (M)", "FLOPs (G)", "MACs (G)", "DMAs (G)", "RF"]
    max_w = [20, 10, 10, 10, 10, 10]

    info_str = [(" " * margin).join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, max_w, strict=False)])]
    info_str.append("-" * len(info_str[0]))
    print("\n".join(info_str))
    for name in TORCHVISION_MODELS:
        model = models.__dict__[name]().eval().to(device)
        dsize = (3, 224, 224)
        if "inception" in name:
            dsize = (3, 299, 299)
        model_info = crawl_module(model, dsize)

        tot_params = sum(layer["grad_params"] + layer["nograd_params"] for layer in model_info["layers"])
        tot_flops = sum(layer["flops"] for layer in model_info["layers"])
        tot_macs = sum(layer["macs"] for layer in model_info["layers"])
        tot_dmas = sum(layer["dmas"] for layer in model_info["layers"])
        rf = model_info["layers"][0]["rf"]
        print(
            f"{name:<{max_w[0]}} | {tot_params / 1e6:<{max_w[1]}.2f} | {tot_flops / 1e9:<{max_w[2]}.2f} | "
            f"{tot_macs / 1e9:<{max_w[3]}.2f} | {tot_dmas / 1e9:<{max_w[4]}.2f} | {rf:<{max_w[5]}.0f}"
        )


if __name__ == "__main__":
    main()


================================================
FILE: setup.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.


import os
from pathlib import Path

from setuptools import setup

PKG_NAME = "torchscan"
VERSION = os.getenv("BUILD_VERSION", "0.2.0.dev0")


if __name__ == "__main__":
    print(f"Building wheel {PKG_NAME}-{VERSION}")

    # Dynamically set the __version__ attribute
    cwd = Path(__file__).parent.absolute()
    with cwd.joinpath("torchscan", "version.py").open("w", encoding="utf-8") as f:
        f.write(f"__version__ = '{VERSION}'\n")

    setup(name=PKG_NAME, version=VERSION)


================================================
FILE: tests/test_crawler.py
================================================
import io
import sys
from collections import OrderedDict

import pytest
import torch.nn as nn

from torchscan import crawler


def test_apply():
    multi_convs = nn.Sequential(nn.Conv2d(16, 32, 3), nn.Conv2d(32, 64, 3))
    mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs)

    # Tag module attributes
    def tag_name(mod, name):
        mod.__depth__ = len(name.split(".")) - 1
        mod.__name__ = name.rpartition(".")[-1]

    crawler.apply(mod, tag_name)

    assert mod[1][1].__depth__ == 2
    assert mod[1][1].__name__ == "1"


def test_crawl_module():
    mod = nn.Conv2d(3, 8, 3)

    res = crawler.crawl_module(mod, (3, 32, 32))
    assert isinstance(res, dict)
    assert res["overall"]["grad_params"] == 224
    assert res["layers"][0]["output_shape"] == (-1, 8, 30, 30)


def test_summary():
    mod = nn.Conv2d(3, 8, 3)

    # Redirect stdout with StringIO object
    captured_output = io.StringIO()
    sys.stdout = captured_output
    crawler.summary(mod, (3, 32, 32))
    # Reset redirect.
    sys.stdout = sys.__stdout__
    assert captured_output.getvalue().split("\n")[7] == "Total params: 224"

    # Check receptive field
    captured_output = io.StringIO()
    sys.stdout = captured_output
    crawler.summary(mod, (3, 32, 32), receptive_field=True)
    # Reset redirect.
    sys.stdout = sys.__stdout__
    assert captured_output.getvalue().split("\n")[1].rpartition("  ")[-1] == "Receptive field"
    assert captured_output.getvalue().split("\n")[3].split()[-1] == "3"
    # Check effective stats
    captured_output = io.StringIO()
    sys.stdout = captured_output
    crawler.summary(mod, (3, 32, 32), receptive_field=True, effective_rf_stats=True)
    # Reset redirect.
    sys.stdout = sys.__stdout__
    assert captured_output.getvalue().split("\n")[1].rpartition("  ")[-1] == "Effective padding"
    assert captured_output.getvalue().split("\n")[3].split()[-1] == "0"

    # Max depth > model hierarchy
    with pytest.raises(ValueError):
        crawler.summary(mod, (3, 32, 32), max_depth=1)

    mod = nn.Sequential(
        OrderedDict([
            ("features", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True))),
            ("pool", nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))),
            ("classifier", nn.Linear(8, 1)),
        ])
    )

    captured_output = io.StringIO()
    sys.stdout = captured_output
    crawler.summary(mod, (3, 32, 32), max_depth=1)
    # Reset redirect.
    sys.stdout = sys.__stdout__
    assert captured_output.getvalue().split("\n")[4].startswith("├─features ")


================================================
FILE: tests/test_modules.py
================================================
import pytest
import torch
from torch import nn

from torchscan import modules


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()


def test_module_flops_warning():
    with pytest.warns(UserWarning):
        modules.module_flops(MyModule(), None, None)


@pytest.mark.parametrize(
    ("mod", "input_shape", "output_shape", "expected_val"),
    [
        # Check for unknown module that it returns 0 and throws a warning
        (MyModule(), (1,), (1,), 0),
        # Fully-connected
        (nn.Linear(8, 4), (1, 8), (1, 4), 4 * (2 * 8 - 1) + 4),
        (nn.Linear(8, 4, bias=False), (1, 8), (1, 4), 4 * (2 * 8 - 1)),
        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 2 * (4 * (2 * 8 - 1) + 4)),
        # Activations
        (nn.Identity(), (1, 8), (1, 8), 0),
        (nn.Flatten(), (1, 8), (1, 8), 0),
        (nn.ReLU(), (1, 8), (1, 8), 8),
        (nn.ELU(), (1, 8), (1, 8), 48),
        (nn.LeakyReLU(), (1, 8), (1, 8), 32),
        (nn.ReLU6(), (1, 8), (1, 8), 16),
        (nn.Tanh(), (1, 8), (1, 8), 48),
        (nn.Sigmoid(), (1, 8), (1, 8), 32),
        # BN
        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 144 + 32 + 32 * 3 + 48),
        # Pooling
        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        (nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        # Dropout
        (nn.Dropout(), (1, 8), (1, 8), 8),
        (nn.Dropout(p=0), (1, 8), (1, 8), 0),
        # Conv
        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 388800),
        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 499408),
    ],
)
def test_module_flops(mod, input_shape, output_shape, expected_val):
    assert modules.module_flops(mod, (torch.zeros(input_shape),), torch.zeros(output_shape)) == expected_val


def test_transformer_flops():
    mod = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=3)
    src = torch.rand((10, 16, 64))
    tgt = torch.rand((20, 16, 64))
    assert modules.module_flops(mod, (src, tgt), mod(src, tgt)) == 774952841


def test_module_macs_warning():
    with pytest.warns(UserWarning):
        modules.module_macs(MyModule(), None, None)


@pytest.mark.parametrize(
    ("mod", "input_shape", "output_shape", "expected_val"),
    [
        # Check for unknown module that it returns 0 and throws a warning
        (MyModule(), (1,), (1,), 0),
        # Fully-connected
        (nn.Linear(8, 4), (1, 8), (1, 4), 8 * 4),
        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 8 * 4 * 2),
        # Activations
        (nn.ReLU(), (1, 8), (1, 8), 0),
        # BN
        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 64 + 24 + 56 + 32),
        # Pooling
        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32),
        (nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        (nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32),
        # Dropout
        (nn.Dropout(), (1, 8), (1, 8), 0),
        # Conv
        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 194400),
        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 249704),
    ],
)
def test_module_macs(mod, input_shape, output_shape, expected_val):
    assert modules.module_macs(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val


def test_module_dmas_warning():
    with pytest.warns(UserWarning):
        modules.module_dmas(MyModule(), None, None)


@pytest.mark.parametrize(
    ("mod", "input_shape", "output_shape", "expected_val"),
    [
        # Check for unknown module that it returns 0 and throws a warning
        (MyModule(), (1,), (1,), 0),
        # Fully-connected
        (nn.Linear(8, 4), (1, 8), (1, 4), 4 * (8 + 1) + 8 + 4),
        (nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 4 * (8 + 1) + 2 * (8 + 4)),
        # Activations
        (nn.Identity(), (1, 8), (1, 8), 8),
        (nn.Flatten(), (1, 8), (1, 8), 16),
        (nn.ReLU(), (1, 8), (1, 8), 8 * 2),
        (nn.ReLU(inplace=True), (1, 8), (1, 8), 8),
        (nn.ELU(), (1, 8), (1, 8), 17),
        (nn.Tanh(), (1, 8), (1, 8), 24),
        (nn.Sigmoid(), (1, 8), (1, 8), 16),
        # BN
        (nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 32 + 17 + 16 + 1 + 17 + 32),
        # Pooling
        (nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),
        (nn.MaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),
        (nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),
        (nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32),
        # Dropout
        (nn.Dropout(), (1, 8), (1, 8), 17),
        # Conv
        (nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 201824),
        (nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 259178),
    ],
)
def test_module_dmas(mod, input_shape, output_shape, expected_val):
    assert modules.module_dmas(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val


# @torch.no_grad()
# def test_module_rf(self):

#     # Check for unknown module that it returns 0 and throws a warning
#     self.assertEqual(modules.module_rf(MyModule(), None, None), (1, 1, 0))
#     self.assertWarns(UserWarning, modules.module_rf, MyModule(), None, None)

#     # Common unit tests
#     # Linear
#     self.assertEqual(modules.module_rf(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))),
#                      (1, 1, 0))
#     # Activation
#     self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     self.assertEqual(modules.module_rf(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
#     # Conv
#     input_t = torch.rand((1, 3, 32, 32))
#     mod = nn.Conv2d(3, 8, 3)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (3, 1, 0))
#     # Check for dilation support
#     mod = nn.Conv2d(3, 8, 3, dilation=2)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (5, 1, 0))
#     # ConvTranspose
#     mod = nn.ConvTranspose2d(3, 8, 3)
#     self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (-3, 1, 0))
#     # BN
#     self.assertEqual(modules.module_rf(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))),
#                      (1, 1, 0))

#     # Pooling
#     self.assertEqual(modules.module_rf(nn.MaxPool2d((2, 2)),
#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
#                      (2, 2, 0))
#     self.assertEqual(modules.module_rf(nn.AdaptiveMaxPool2d((2, 2)),
#                                        torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
#                      (2, 2, 0))

#     # Dropout
#     self.assertEqual(modules.module_rf(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))


================================================
FILE: tests/test_process.py
================================================
import os

import torch

from torchscan import process


def test_get_process_gpu_ram():
    if torch.cuda.is_initialized():
        assert process.get_process_gpu_ram(os.getpid()) >= 0
    else:
        assert process.get_process_gpu_ram(os.getpid()) == 0


================================================
FILE: tests/test_utils.py
================================================
import pytest

from torchscan import utils


def test_format_name():
    name = "mymodule"
    assert utils.format_name(name) == name
    assert utils.format_name(name, depth=1) == f"├─{name}"
    assert utils.format_name(name, depth=3) == f"|    |    └─{name}"


def test_wrap_string():
    example = ".".join(["a" for _ in range(10)])
    max_len = 10
    wrap = "[...]"

    assert utils.wrap_string(example, max_len, mode="end") == example[: max_len - len(wrap)] + wrap
    assert utils.wrap_string(example, max_len, mode="mid") == f"{example[: max_len - 2 - len(wrap)]}{wrap}.a"
    assert utils.wrap_string(example, len(example), mode="end") == example
    with pytest.raises(ValueError):
        _ = utils.wrap_string(example, max_len, mode="test")


@pytest.mark.parametrize(
    ("input_val", "num_val", "unit"),
    [
        (3e14, 300, "T"),
        (3e10, 30, "G"),
        (3e7, 30, "M"),
        (15e3, 15, "k"),
        (500, 500, ""),
    ],
)
def test_unit_scale(input_val, num_val, unit):
    assert utils.unit_scale(input_val) == (num_val, unit)


================================================
FILE: torchscan/__init__.py
================================================
from contextlib import suppress
from torchscan import modules, process, utils
from torchscan.crawler import *

with suppress(ImportError):
    from .version import __version__


================================================
FILE: torchscan/crawler.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch.nn import Module

from .modules import module_dmas, module_flops, module_macs, module_rf
from .process import get_process_gpu_ram
from .utils import aggregate_info, format_info

__all__ = ["crawl_module", "summary"]


def apply(module: Module, fn: Callable[[Module, str], None], name: Optional[str] = None) -> None:
    """Modified version of `torch.nn.Module.apply` method

    Args:
        module: target module
        fn: function to apply to each module
        name: name of the current module
    """
    if name is None:
        name = module.__class__.__name__.lower()
    fn(module, name)
    for n, m in module.named_children():
        apply(m, fn, f"{name}.{n}")


def crawl_module(
    module: Module,
    input_shape: Union[List[Tuple[int, ...]], Tuple[int, ...]],
    dtype: Optional[Union[torch.dtype, Iterable[torch.dtype]]] = None,
) -> Dict[str, Any]:
    """Retrieves module information for an expected input tensor shape

    >>> import torch.nn as nn
    >>> from torchscan import summary
    >>> mod = nn.Conv2d(3, 8, 3)
    >>> module_info = crawl_module(mod, (3, 224, 224))

    Args:
        module: module to inspect
        input_shape: expected input shapes
        dtype: data type of each input argument to the module
    Returns:
        layer and overhead information
    """
    # Get device and data types from model
    p = next(module.parameters())
    device = p.device

    cuda_overhead, framework_overhead = 0.0, 0.0
    if torch.cuda.is_available():
        # Process RAM - allocator RAM
        cuda_overhead = get_process_gpu_ram(os.getpid()) - (torch.cuda.memory_reserved() / 1024**2)
        # Allocator RAM - Used RAM
        framework_overhead = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024**2

    # input
    if not isinstance(input_shape, list):
        input_shape = [input_shape]
    if dtype is None:
        dtype = p.data.dtype
    if isinstance(dtype, torch.dtype):
        dtype = [dtype] * len(input_shape)
    # Tensor arguments
    input_ts = [
        torch.rand(1, *in_shape).to(dtype=_dtype, device=device)
        for in_shape, _dtype in zip(input_shape, dtype, strict=False)
    ]

    pre_fw_handles, post_fw_handles = [], []
    pre_hook_tracker: Dict[int, Any] = {}
    post_hook_tracker: Dict[int, Any] = {}

    # Hook definition
    def _hook_info(module: Module, name: str) -> None:
        def _pre_hook(module: Module, inp: torch.Tensor) -> None:
            """Pre-forward hook"""
            # Check that another hook has not been triggered at this forward stage
            if not pre_hook_tracker[id(module)]["is_used"] and (
                pre_hook_tracker[id(module)]["target"] == pre_hook_tracker[id(module)]["current"]
            ):
                # Add information
                # Params
                grad_params, nograd_params, param_size = 0, 0, 0
                num_buffers, buffer_size = 0, 0
                is_shared = False
                if not any(module.children()):
                    # Parameters
                    for p in module.parameters():
                        if id(p) not in param_ids:
                            if p.requires_grad:
                                grad_params += p.data.numel()
                            else:
                                nograd_params += p.data.numel()
                            param_size += p.data.numel() * p.data.element_size()
                            param_ids.append(id(p))
                        else:
                            is_shared = True
                    # Buffers
                    for b in module.buffers():
                        if id(b) not in param_ids:
                            num_buffers += b.numel()
                            buffer_size += b.numel() * b.element_size()
                            param_ids.append(id(b))
                        else:
                            is_shared = True

                if call_idxs.get(id(module)) is None:
                    call_idxs[id(module)] = [len(info)]
                else:
                    call_idxs[id(module)].append(len(info))

                info.append({
                    "name": name.rpartition(".")[-1],
                    "depth": len(name.split(".")) - 1,
                    "type": module.__class__.__name__,
                    "input_shape": (-1, *inp[0][0].shape[1:]),
                    "output_shape": None,
                    "grad_params": grad_params,
                    "nograd_params": nograd_params,
                    "param_size": param_size,
                    "num_buffers": num_buffers,
                    "buffer_size": buffer_size,
                    "flops": 0,
                    "macs": 0,
                    "dmas": 0,
                    "rf": 1,
                    "s": 1,
                    "p": 0,
                    "is_shared": is_shared,
                    "is_leaf": not any(module.children()),
                })
                # Mark the next hook for execution
                pre_hook_tracker[id(module)]["target"] += 1
                # Current pass already used one of the hooks
                pre_hook_tracker[id(module)]["is_used"] = True
            pre_hook_tracker[id(module)]["current"] += 1
            # All the hooks have been checked, reset the temporary values
            if pre_hook_tracker[id(module)]["current"] == len(module._forward_pre_hooks):
                pre_hook_tracker[id(module)]["current"] = 0
                pre_hook_tracker[id(module)]["is_used"] = False

        def _fwd_hook(module: Module, inputs: Tuple[torch.Tensor, ...], out: torch.Tensor) -> None:
            """Post-forward hook"""
            # Check that another hook has not been triggered at this forward stage
            if not post_hook_tracker[id(module)]["is_used"] and (
                post_hook_tracker[id(module)]["target"] == post_hook_tracker[id(module)]["current"]
            ):
                # Write information
                # Retrieve forward index
                if len(call_idxs[id(module)]) == 1:
                    fw_idx = call_idxs[id(module)][0]
                else:
                    # The first dictionary with output_shape=None is the correct one
                    for _idx in call_idxs[id(module)]:
                        if info[_idx]["output_shape"] is None:
                            fw_idx = _idx
                            break

                if any(module.children()):
                    tot_flops, tot_macs, tot_dmas = 0, 0, 0
                    current_rf, current_stride, current_padding = 1.0, 1.0, 0.0
                else:
                    # Compute stats for standalone layers
                    tot_flops = module_flops(module, inputs, out)
                    tot_macs = module_macs(module, inputs[0], out)
                    tot_dmas = module_dmas(module, inputs[0], out)
                    current_rf, current_stride, current_padding = module_rf(module, inputs[0], out)

                # Update layer information
                info[fw_idx]["output_shape"] = (-1, *out.shape[1:])
                # Assign stats to this call's entry (modules used several times get one entry per call)
                info[fw_idx]["flops"] = tot_flops
                info[fw_idx]["macs"] = tot_macs
                info[fw_idx]["dmas"] = tot_dmas
                # Compute receptive field
                info[fw_idx]["rf"] = current_rf
                info[fw_idx]["s"] = current_stride
                info[fw_idx]["p"] = current_padding

                # Mark the next hook for execution
                post_hook_tracker[id(module)]["target"] += 1
                # Current pass already used one of the hooks
                post_hook_tracker[id(module)]["is_used"] = True
            post_hook_tracker[id(module)]["current"] += 1
            # All the hooks have been checked, reset the temporary values
            if post_hook_tracker[id(module)]["current"] == len(module._forward_hooks):
                post_hook_tracker[id(module)]["current"] = 0
                post_hook_tracker[id(module)]["is_used"] = False

        pre_fw_handles.append(module.register_forward_pre_hook(_pre_hook))  # type: ignore[arg-type]
        post_fw_handles.append(module.register_forward_hook(_fwd_hook))
        # Handle modules that are used multiple times (with several hooks)
        pre_hook_tracker[id(module)] = {"current": 0, "target": 0, "is_used": False}
        post_hook_tracker[id(module)] = {"current": 0, "target": 0, "is_used": False}

    # Hook model
    info: List[Dict[str, Any]] = []
    param_ids: List[int] = []
    call_idxs: Dict[int, List[int]] = {}
    apply(module, _hook_info)

    # Forward
    with torch.no_grad():
        module(*input_ts)

    # Removes all hooks using their handles
    for handle in pre_fw_handles:
        handle.remove()
    for handle in post_fw_handles:
        handle.remove()

    reserved_ram, diff_ram = 0.0, 0.0
    if torch.cuda.is_available():
        reserved_ram = torch.cuda.memory_reserved() / 1024**2
        diff_ram = (torch.cuda.memory_reserved() - torch.cuda.memory_allocated()) / 1024**2
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    grad_params, nograd_params, param_size = 0, 0, 0
    num_buffers, buffer_size = 0, 0
    for p in module.parameters():
        if p.requires_grad:
            grad_params += p.data.numel()
        else:
            nograd_params += p.data.numel()
        param_size += p.data.numel() * p.data.element_size()
    for b in module.buffers():
        num_buffers += b.numel()
        buffer_size += b.numel() * b.element_size()

    # Update cumulative receptive field
    _rf, _s, _p = 1, 1, 0
    for fw_idx, _layer in enumerate(info):
        _rf += _s * (_layer["rf"] - 1)
        _p += _s * _layer["p"]
        _s *= _layer["s"]
        info[fw_idx]["rf"] = _rf
        info[fw_idx]["s"] = _s
        info[fw_idx]["p"] = _p

    return {
        "overheads": {
            "cuda": {
                "pre": cuda_overhead,
                "fwd": get_process_gpu_ram(os.getpid()) - reserved_ram,
            },
            "framework": {"pre": framework_overhead, "fwd": diff_ram},
        },
        "layers": info,
        "overall": {
            "grad_params": grad_params,
            "nograd_params": nograd_params,
            "param_size": param_size,
            "num_buffers": num_buffers,
            "buffer_size": buffer_size,
        },
    }
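The loop at the end of `crawl_module` composes per-layer receptive-field stats into cumulative ones. The same arithmetic in isolation, with hypothetical layer stats (a stride-1 then a stride-2 convolution, both kernel 3, padding 1):

```python
# Hypothetical per-layer (rf, stride, padding) stats, as module_rf would return them
layers = [
    {"rf": 3, "s": 1, "p": 1},  # e.g. Conv2d(k=3, s=1, p=1)
    {"rf": 3, "s": 2, "p": 1},  # e.g. Conv2d(k=3, s=2, p=1)
]

# Same update rule as in crawl_module: each layer's stats are composed
# with the cumulative stride of everything before it
_rf, _s, _p = 1, 1, 0
for layer in layers:
    _rf += _s * (layer["rf"] - 1)
    _p += _s * layer["p"]
    _s *= layer["s"]
    layer["rf"], layer["s"], layer["p"] = _rf, _s, _p
# layers == [{"rf": 3, "s": 1, "p": 1}, {"rf": 5, "s": 2, "p": 2}]
```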


def summary(
    module: Module,
    input_shape: Tuple[int, ...],
    wrap_mode: str = "mid",
    max_depth: Optional[int] = None,
    receptive_field: bool = False,
    effective_rf_stats: bool = False,
) -> None:
    """Print module summary for an expected input tensor shape

    >>> import torch.nn as nn
    >>> from torchscan import summary
    >>> mod = nn.Conv2d(3, 8, 3)
    >>> summary(mod, (3, 224, 224), receptive_field=True)

    Args:
        module: module to inspect
        input_shape: expected input shapes (don't include batch size)
        wrap_mode: if a value is too long, where the wrapping should be performed
        max_depth: maximum depth of layer information
        receptive_field: whether receptive field estimation should be performed
        effective_rf_stats: if `receptive_field` is True, displays effective stride and padding
    """
    # Get the summary dict
    module_info = crawl_module(module, input_shape)
    # Aggregate until max_depth
    if isinstance(max_depth, int):
        module_info = aggregate_info(module_info, max_depth)
    # Format it and print it
    print(format_info(module_info, wrap_mode, receptive_field, effective_rf_stats))  # noqa T201
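`summary` hands the dict returned by `crawl_module` to `aggregate_info`/`format_info`, but the dict can also be consumed directly. A sketch with a hypothetical two-layer payload standing in for a real crawl (all values made up for illustration):

```python
# Hypothetical crawl_module-style payload mirroring the schema above
module_info = {
    "overheads": {"cuda": {"pre": 0.0, "fwd": 0.0}, "framework": {"pre": 0.0, "fwd": 0.0}},
    "layers": [
        {"name": "conv", "type": "Conv2d", "flops": 10_000, "grad_params": 224, "is_leaf": True},
        {"name": "relu", "type": "ReLU", "flops": 1_000, "grad_params": 0, "is_leaf": True},
    ],
    "overall": {"grad_params": 224, "nograd_params": 0, "param_size": 896,
                "num_buffers": 0, "buffer_size": 0},
}

# Sum leaf-layer FLOPs, the way a report would
total_flops = sum(layer["flops"] for layer in module_info["layers"] if layer["is_leaf"])
# total_flops == 11000
```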


================================================
FILE: torchscan/modules/__init__.py
================================================
from .flops import *
from .macs import *
from .memory import *
from .receptive import *


================================================
FILE: torchscan/modules/flops.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import warnings
from functools import reduce
from operator import mul
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_flops"]


def module_flops(module: Module, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """Estimate the number of floating point operations performed by the module

    Args:
        module: PyTorch module
        inputs: input to the module
        out: output of the module
    Returns:
        number of FLOPs
    """
    if isinstance(module, (nn.Identity, nn.Flatten)):
        return 0
    if isinstance(module, nn.Linear):
        return flops_linear(module, inputs)
    if isinstance(module, nn.ReLU):
        return flops_relu(module, inputs)
    if isinstance(module, nn.ELU):
        return flops_elu(module, inputs)
    if isinstance(module, nn.LeakyReLU):
        return flops_leakyrelu(module, inputs)
    if isinstance(module, nn.ReLU6):
        return flops_relu6(module, inputs)
    if isinstance(module, nn.Tanh):
        return flops_tanh(module, inputs)
    if isinstance(module, nn.Sigmoid):
        return flops_sigmoid(module, inputs)
    if isinstance(module, _ConvTransposeNd):
        return flops_convtransposend(module, inputs, out)
    if isinstance(module, _ConvNd):
        return flops_convnd(module, inputs, out)
    if isinstance(module, _BatchNorm):
        return flops_bn(module, inputs)
    if isinstance(module, _MaxPoolNd):
        return flops_maxpool(module, inputs, out)
    if isinstance(module, _AvgPoolNd):
        return flops_avgpool(module, inputs, out)
    if isinstance(module, _AdaptiveMaxPoolNd):
        return flops_adaptive_maxpool(module, inputs, out)
    if isinstance(module, _AdaptiveAvgPoolNd):
        return flops_adaptive_avgpool(module, inputs, out)
    if isinstance(module, nn.Dropout):
        return flops_dropout(module, inputs)
    if isinstance(module, nn.Transformer):
        return flops_transformer(module, inputs)
    warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1)
    return 0


def flops_linear(module: nn.Linear, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.Linear`"""
    # number of output elements: out_features * all leading dims (batch, sequence, ...)
    num_out_feats = module.out_features * reduce(mul, inputs[0].shape[:-1])
    mm_flops = num_out_feats * (2 * module.in_features - 1)
    bias_flops = num_out_feats if module.bias is not None else 0

    return mm_flops + bias_flops
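The count above, worked through by hand for a hypothetical `nn.Linear(64, 10)` with bias on a `(32, 64)` input (pure arithmetic, no torch needed):

```python
# Worked example of the flops_linear formula
batch_size, in_features, out_features = 32, 64, 10
num_out_feats = out_features * batch_size         # 10 * 32 = 320 output elements
mm_flops = num_out_feats * (2 * in_features - 1)  # 320 * 127 = 40640 (muls + adds)
bias_flops = num_out_feats                        # one add per output element
total = mm_flops + bias_flops                     # 40960
```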


def flops_sigmoid(_: nn.Sigmoid, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.Sigmoid`"""
    # For each element, mul by -1, exp it, add 1, div
    return inputs[0].numel() * 4


def flops_relu(_: nn.ReLU, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.ReLU`"""
    # Each element is compared to 0
    return inputs[0].numel()


def flops_elu(_: nn.ELU, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.ELU`"""
    # For each element, compare it to 0, exp it, sub 1, mul by alpha, compare it to 0 and sum both
    return inputs[0].numel() * 6


def flops_leakyrelu(_: nn.LeakyReLU, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.LeakyReLU`"""
    # For each element, compare it to 0 (max), compare it to 0 (min), mul by slope and sum both
    return inputs[0].numel() * 4


def flops_relu6(_: nn.ReLU6, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.ReLU6`"""
    # For each element, compare it to 0 (max), then to 6 (min)
    return inputs[0].numel() * 2


def flops_tanh(_: nn.Tanh, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.Tanh`"""
    # For each element: exp(x) and exp(-x) (including the negation), then one sub, one add, one div
    return inputs[0].numel() * 6


def flops_dropout(module: nn.Dropout, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.Dropout`"""
    if module.p > 0:
        # Sample a random number for each input element
        return inputs[0].numel()
    return 0


def flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.conv._ConvTranposeNd`"""
    # Padding (cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
    # Define min and max sizes
    padding_flops = len(module.kernel_size) * 8

    # Once padding is determined, the operations are almost identical to those of a convolution
    conv_flops = flops_convnd(module, inputs, out)

    return padding_flops + conv_flops


def flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.conv._ConvNd`"""
    # For each position, # mult = kernel size, # adds = kernel size - 1
    window_flops_per_chan = 2 * reduce(mul, module.kernel_size) - 1
    # Connections to input channels is controlled by the group parameter
    effective_in_chan = inputs[0].shape[1] // module.groups
    # N * flops + (N - 1) additions
    window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1)
    conv_flops = out.numel() * window_flops

    # Each output element gets a bias addition
    bias_flops = out.numel() if module.bias is not None else 0

    return conv_flops + bias_flops
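The formula above, worked through by hand for a hypothetical `nn.Conv2d(3, 8, 3)` with bias on a `(1, 3, 224, 224)` input, which produces a `(1, 8, 222, 222)` output:

```python
# Worked example of the flops_convnd formula (pure arithmetic)
kernel_numel, in_chan, groups = 3 * 3, 3, 1
out_numel = 1 * 8 * 222 * 222                                                       # 394272
window_flops_per_chan = 2 * kernel_numel - 1                                        # 17
effective_in_chan = in_chan // groups                                               # 3
window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1)  # 53
total = out_numel * window_flops + out_numel  # conv FLOPs + one bias add per output
```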


def flops_bn(module: _BatchNorm, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    # for each channel, add eps and running_var, sqrt it
    norm_ops = module.num_features * 2
    # For each element, sub running_mean, div by denom
    norm_ops += inputs[0].numel() * 2
    # For each element, mul by gamma, add beta
    scale_ops = inputs[0].numel() * 2 if module.affine else 0
    bn_flops = norm_ops + scale_ops

    # Count tracking stats update ops
    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101
    tracking_flops = 0
    if module.track_running_stats and module.training:
        # exponential_average_factor
        if module.momentum is None:
            tracking_flops += 1
        # running_mean: by channel, sum values and div by batch size
        tracking_flops += inputs[0].numel()
        # running_var: by channel, sub mean and square values, sum them, divide by batch size
        tracking_flops += 3 * inputs[0].numel()
        # Update both running stats: rescale the previous value (mul by N), add the new one, then div by (N + 1)
        tracking_flops += 2 * module.num_features * 3

    return bn_flops + tracking_flops


def flops_maxpool(module: _MaxPoolNd, _: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

    # for each spatial output element, check max element in kernel scope
    return out.numel() * (k_size - 1)


def flops_avgpool(module: _AvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

    # for each spatial output element, sum elements in kernel scope and div by kernel size
    return out.numel() * (k_size - 1 + inputs[0].ndim - 2)


def flops_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""
    # Approximate kernel_size using ratio of spatial shapes between input and output
    kernel_size = tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:], strict=False)
    )

    # for each spatial output element, check max element in kernel scope
    return out.numel() * (reduce(mul, kernel_size) - 1)


def flops_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...], out: Tensor) -> int:
    """FLOPs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""
    # Approximate kernel_size using ratio of spatial shapes between input and output
    kernel_size = tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inputs[0].shape[2:], out.shape[2:], strict=False)
    )

    # for each spatial output element, sum elements in kernel scope and div by kernel size
    return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))
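The kernel-size approximation shared by both adaptive-pooling estimators above, in isolation: when the input size divides evenly it recovers the true kernel size, otherwise it estimates the widest window.

```python
def approx_kernel(i_size, o_size):
    # Exact divisor -> true kernel size; otherwise estimate the widest window
    if i_size % o_size == 0:
        return i_size // o_size
    return i_size - o_size * (i_size // o_size) + 1

even = approx_kernel(224, 7)   # 224 divides evenly -> 32
uneven = approx_kernel(10, 3)  # 10 = 3*3 + 1 -> 10 - 9 + 1 = 2
```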


def flops_layernorm(module: nn.LayerNorm, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    # Compute current mean
    norm_ops = reduce(mul, module.normalized_shape) * inputs[0].shape[: -len(module.normalized_shape)].numel()
    # current var (sub the mean, square it, sum them, divide by remaining shape)
    norm_ops += 3 * inputs[0].numel()
    # for each channel, add eps and running_var, sqrt it
    norm_ops += reduce(mul, module.normalized_shape) * 2
    # For each element, sub running_mean, div by denom
    norm_ops += inputs[0].numel() * 2
    # For each element, mul by gamma, add beta
    scale_ops = inputs[0].numel() * 2 if module.elementwise_affine else 0

    return norm_ops + scale_ops


def flops_mha(module: nn.MultiheadAttention, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.MultiheadAttention`"""
    # Input projection
    q, k, _ = inputs[:3]
    batch_size = q.shape[1]
    if module._qkv_same_embed_dim:
        tot_flops = 3 * flops_linear(
            nn.Linear(
                module.in_proj_weight.shape[1], module.in_proj_weight.shape[0], bias=module.in_proj_bias is not None
            ),
            (torch.empty((batch_size, module.in_proj_weight.shape[1])),),
        )
    else:
        tot_flops = flops_linear(
            nn.Linear(
                module.q_proj_weight.shape[1], module.q_proj_weight.shape[0], bias=module.in_proj_bias is not None
            ),
            (torch.empty((batch_size, module.q_proj_weight.shape[1])),),
        )
        tot_flops += flops_linear(
            nn.Linear(module.k_proj_weight.shape[1], module.k_proj_weight.shape[0], bias=module.bias_k is not None),
            (torch.empty((batch_size, module.k_proj_weight.shape[1])),),
        )
        tot_flops += flops_linear(
            nn.Linear(module.v_proj_weight.shape[1], module.v_proj_weight.shape[0], bias=module.bias_v is not None),
            (torch.empty((batch_size, module.v_proj_weight.shape[1])),),
        )

    # Q (L, B, embed_dim) --> (B * num_heads, L, head_dim=embed_dim / num_heads)

    # Scaled dot-product attention (cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L5083)
    # sqrt the embedding dim and div the Q with it
    tot_flops += 1 + batch_size * module.num_heads * module.head_dim * q.shape[0]
    # batched matrix multiply
    tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0]) * (2 * module.head_dim - 1)
    # attention mask
    if inputs[-1] is not None:
        tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0])

    # softmax
    tot_flops += batch_size * module.num_heads * q.shape[0] * (3 * k.shape[0] - 1)
    # dropout
    if module.dropout > 0:
        tot_flops += batch_size * module.num_heads * (q.shape[0] * k.shape[0])

    # batched matrix multiply
    tot_flops += batch_size * module.num_heads * (q.shape[0] * module.head_dim) * (2 * k.shape[0] - 1)
    # Output linear projection
    tot_flops += flops_linear(module.out_proj, (torch.empty((q.shape[0], module.out_proj.in_features)),))

    return tot_flops


def flops_transformer_encoderlayer(module: nn.TransformerEncoderLayer, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.TransformerEncoderLayer`"""
    tot_flops = flops_mha(module.self_attn, (inputs[0],) * 3)

    tot_flops += flops_dropout(module.dropout1, inputs) + inputs[0].numel()
    tot_flops += flops_layernorm(module.norm1, inputs)
    # get linear 1 output size
    tot_flops += flops_linear(module.linear1, inputs)
    tot_flops += module_flops(module.activation, inputs, torch.empty(1))  # type: ignore[arg-type]
    tot_flops += flops_dropout(module.dropout, inputs) + flops_linear(module.linear2, inputs)
    # get linear 2 output size
    tot_flops += flops_dropout(module.dropout2, inputs) + inputs[0].numel()
    tot_flops += flops_layernorm(module.norm2, inputs)

    return tot_flops


def flops_transformer_decoderlayer(module: nn.TransformerDecoderLayer, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.TransformerEncoderLayer`"""
    tot_flops = flops_mha(module.self_attn, (inputs[0],) * 3)

    tot_flops += flops_dropout(module.dropout1, inputs) + inputs[0].numel()
    tot_flops += flops_layernorm(module.norm1, inputs)

    tot_flops += flops_mha(module.multihead_attn, (inputs[0], inputs[1], inputs[1]))
    tot_flops += flops_dropout(module.dropout2, inputs) + inputs[0].numel()
    tot_flops += flops_layernorm(module.norm2, inputs)

    # get linear 1 output size
    tot_flops += flops_linear(module.linear1, inputs)
    tot_flops += module_flops(module.activation, inputs, torch.empty(1))  # type: ignore[arg-type]
    tot_flops += flops_dropout(module.dropout, inputs) + flops_linear(module.linear2, inputs)
    # get linear 2 output size
    tot_flops += flops_dropout(module.dropout3, inputs) + inputs[0].numel()
    tot_flops += flops_layernorm(module.norm3, inputs)

    return tot_flops


def flops_transformer(module: nn.Transformer, inputs: Tuple[Tensor, ...]) -> int:
    """FLOPs estimation for `torch.nn.Transformer`"""
    encoder_flops = len(module.encoder.layers) * flops_transformer_encoderlayer(module.encoder.layers[0], inputs)

    if module.encoder.norm is not None:
        encoder_flops += flops_layernorm(module.encoder.norm, inputs)

    decoder_flops = len(module.decoder.layers) * flops_transformer_decoderlayer(module.decoder.layers[0], inputs)

    if module.decoder.norm is not None:
        decoder_flops += flops_layernorm(module.decoder.norm, inputs)

    return encoder_flops + decoder_flops


================================================
FILE: torchscan/modules/macs.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import warnings
from functools import reduce
from operator import mul

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_macs"]


def module_macs(module: Module, inp: Tensor, out: Tensor) -> int:
    """Estimate the number of multiply-accumulation operations performed by the module

    Args:
        module (torch.nn.Module): PyTorch module
        inp (torch.Tensor): input to the module
        out (torch.Tensor): output of the module
    Returns:
        int: number of MACs
    """
    if isinstance(module, nn.Linear):
        return macs_linear(module, inp, out)
    if isinstance(module, (nn.Identity, nn.ReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, nn.Tanh, nn.Sigmoid, nn.Flatten)):
        return 0
    if isinstance(module, _ConvTransposeNd):
        return macs_convtransposend(module, inp, out)
    if isinstance(module, _ConvNd):
        return macs_convnd(module, inp, out)
    if isinstance(module, _BatchNorm):
        return macs_bn(module, inp, out)
    if isinstance(module, _MaxPoolNd):
        return macs_maxpool(module, inp, out)
    if isinstance(module, _AvgPoolNd):
        return macs_avgpool(module, inp, out)
    if isinstance(module, _AdaptiveMaxPoolNd):
        return macs_adaptive_maxpool(module, inp, out)
    if isinstance(module, _AdaptiveAvgPoolNd):
        return macs_adaptive_avgpool(module, inp, out)
    if isinstance(module, nn.Dropout):
        return 0
    warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1)
    return 0


def macs_linear(module: nn.Linear, _: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.Linear`"""
    # batch size * out_chan * macs_per_elt (bias already counted in accumulation)
    return module.in_features * reduce(mul, out.shape)


def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""
    # Padding (cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
    # Define min and max sizes, then subtract them
    padding_macs = len(module.kernel_size) * 4

    # Rest of the operations are almost identical to a convolution (given the padding)
    conv_macs = macs_convnd(module, inp, out)

    return padding_macs + conv_macs


def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.conv._ConvNd`"""
    # For each position, # macs = kernel size
    window_macs_per_chan = reduce(mul, module.kernel_size)
    # Connections to input channels is controlled by the group parameter
    effective_in_chan = inp.shape[1] // module.groups
    # N * macs per output element (bias already counted in the accumulation)
    window_mac = effective_in_chan * window_macs_per_chan
    return out.numel() * window_mac
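The MAC count above, worked through by hand for a hypothetical `nn.Conv2d(3, 8, 3)` on a `(1, 3, 224, 224)` input, which produces a `(1, 8, 222, 222)` output:

```python
# Worked example of the macs_convnd formula (pure arithmetic)
kernel_numel, in_chan, groups = 3 * 3, 3, 1
out_numel = 1 * 8 * 222 * 222                    # 394272
window_mac = (in_chan // groups) * kernel_numel  # 27 MACs per output element
total_macs = out_numel * window_mac              # 10645344
```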


def macs_bn(module: _BatchNorm, inp: Tensor, _: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    # sub mean, div by denom
    norm_mac = 1
    # mul by gamma, add beta
    scale_mac = 1 if module.affine else 0

    # Sum everything up
    bn_mac = inp.numel() * (norm_mac + scale_mac)

    # Count tracking stats update ops
    # cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L94-L101
    tracking_mac = 0
    b = inp.shape[0]
    num_spatial_elts = inp.shape[2:].numel()
    if module.track_running_stats and module.training:
        # running_mean: by channel, sum value and div by batch size
        tracking_mac += module.num_features * (b * num_spatial_elts - 1)
        # running_var: by channel, sub mean and square values, sum them, divide by batch size
        active_elts = b * num_spatial_elts
        tracking_mac += module.num_features * (2 * active_elts - 1)
        # Update both running stats: rescale the previous value (mul by N), add the new one, then div by (N + 1)
        tracking_mac += 2 * module.num_features * 2

    return bn_mac + tracking_mac


def macs_maxpool(module: _MaxPoolNd, _: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._MaxPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

    # for each spatial output element, check max element in kernel scope
    return out.numel() * (k_size - 1)


def macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AvgPoolNd`"""
    k_size = reduce(mul, module.kernel_size) if isinstance(module.kernel_size, tuple) else module.kernel_size

    # for each spatial output element, sum elements in kernel scope and div by kernel size
    return out.numel() * (k_size - 1 + inp.ndim - 2)


def macs_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AdaptiveMaxPoolNd`"""
    # Approximate kernel_size using ratio of spatial shapes between input and output
    kernel_size = tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False)
    )

    # for each spatial output element, check max element in kernel scope
    return out.numel() * (reduce(mul, kernel_size) - 1)


def macs_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inp: Tensor, out: Tensor) -> int:
    """MACs estimation for `torch.nn.modules.pooling._AdaptiveAvgPoolNd`"""
    # Approximate kernel_size using ratio of spatial shapes between input and output
    kernel_size = tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False)
    )

    # for each spatial output element, sum elements in kernel scope and div by kernel size
    return out.numel() * (reduce(mul, kernel_size) - 1 + len(kernel_size))


================================================
FILE: torchscan/modules/memory.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import warnings
from functools import reduce
from operator import mul
from typing import Union

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_dmas"]


def module_dmas(module: Module, inp: Tensor, out: Tensor) -> int:
    """Estimate the number of direct memory accesses by the module.
    The implementation overhead is neglected.

    Args:
        module (torch.nn.Module): PyTorch module
        inp (torch.Tensor): input to the module
        out (torch.Tensor): output of the module
    Returns:
        int: number of DMAs
    """
    if isinstance(module, nn.Identity):
        return dmas_identity(module, inp, out)
    if isinstance(module, nn.Flatten):
        return dmas_flatten(module, inp, out)
    if isinstance(module, nn.Linear):
        return dmas_linear(module, inp, out)
    if isinstance(module, (nn.ReLU, nn.ReLU6)):
        return dmas_relu(module, inp, out)
    if isinstance(module, (nn.ELU, nn.LeakyReLU)):
        return dmas_act_single_param(module, inp, out)
    if isinstance(module, nn.Sigmoid):
        return dmas_sigmoid(module, inp, out)
    if isinstance(module, nn.Tanh):
        return dmas_tanh(module, inp, out)
    if isinstance(module, _ConvTransposeNd):
        return dmas_convtransposend(module, inp, out)
    if isinstance(module, _ConvNd):
        return dmas_convnd(module, inp, out)
    if isinstance(module, _BatchNorm):
        return dmas_bn(module, inp, out)
    if isinstance(module, (_MaxPoolNd, _AvgPoolNd)):
        return dmas_pool(module, inp, out)
    if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
        return dmas_adaptive_pool(module, inp, out)
    if isinstance(module, nn.Dropout):
        return dmas_dropout(module, inp, out)
    warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1)
    return 0


def num_params(module: Module) -> int:
    """Compute the number of parameters

    Args:
        module (torch.nn.Module): PyTorch module
    Returns:
        int: number of parameter elements
    """
    return sum(p.data.numel() for p in module.parameters())


def dmas_identity(_: nn.Identity, inp: Tensor, __: Tensor) -> int:
    """DMAs estimation for `torch.nn.Identity`"""
    return inp.numel()


def dmas_flatten(_: nn.Flatten, inp: Tensor, __: Tensor) -> int:
    """DMAs estimation for `torch.nn.Flatten`"""
    return 2 * inp.numel()


def dmas_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.Linear`"""
    input_dma = inp.numel()
    # Access weight and bias
    ops_dma = num_params(module)
    output_dma = out.numel()

    return input_dma + ops_dma + output_dma
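The linear-layer estimate above counts one read per input element, one read per parameter, and one write per output element. A standalone sketch with shapes instead of tensors (the helper name is hypothetical):

```python
def linear_dmas(batch, in_features, out_features, bias=True):
    """Sketch of the DMA estimate for a fully-connected layer:
    input reads + parameter reads (weight, optional bias) + output writes."""
    input_dma = batch * in_features
    ops_dma = in_features * out_features + (out_features if bias else 0)
    output_dma = batch * out_features
    return input_dma + ops_dma + output_dma


print(linear_dmas(1, 8, 4))  # 8 + (32 + 4) + 4 = 48
```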


def dmas_relu(module: Union[nn.ReLU, nn.ReLU6], inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.ReLU`"""
    input_dma = inp.numel()
    output_dma = 0 if module.inplace else out.numel()

    return input_dma + output_dma


def dmas_act_single_param(module: Union[nn.ELU, nn.LeakyReLU], inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for activations with single parameter"""
    input_dma = inp.numel()
    # Access alpha, slope or other
    ops_dma = 1
    output_dma = 0 if module.inplace else out.numel()

    return input_dma + ops_dma + output_dma


def dmas_sigmoid(_: nn.Sigmoid, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.Sigmoid`"""
    # Single access to the input for the exp computation
    input_dma = inp.numel()
    output_dma = out.numel()

    return input_dma + output_dma


def dmas_tanh(_: nn.Tanh, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.Tanh`"""
    # Two input accesses, one per exponential in tanh
    input_dma = inp.numel() * 2
    output_dma = out.numel()

    return input_dma + output_dma


def dmas_dropout(module: nn.Dropout, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.Dropout`"""
    input_dma = inp.numel()

    # Access sampling probability
    ops_dma = 1

    output_dma = 0 if module.inplace else out.numel()

    return input_dma + ops_dma + output_dma


def dmas_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.modules.conv._ConvTransposeNd`"""
    # Padding (# cf. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L496-L532)
    # Access stride, padding and kernel_size
    in_padding = len(module.kernel_size) * 4
    out_padding = len(module.kernel_size)

    # The rest is like a classic convolution
    conv_dmas = dmas_convnd(module, inp, out)

    return in_padding + out_padding + conv_dmas


def dmas_convnd(module: _ConvNd, _: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.modules.conv._ConvNd`"""
    # Each output element requires prod(kernel_size) memory accesses per input channel
    input_dma = module.in_channels * reduce(mul, module.kernel_size) * out.numel()
    # Correct with groups
    input_dma //= module.groups

    # Access weight & bias
    ops_dma = num_params(module)
    output_dma = out.numel()

    return input_dma + ops_dma + output_dma
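The convolution estimate above can be reproduced from shapes alone: each output element reads `in_channels / groups * prod(kernel_size)` input values, every parameter is read once, and every output element is written once. A pure-Python sketch (hypothetical helper, weight count follows the usual `(out_channels, in_channels // groups, *kernel_size)` layout):

```python
from functools import reduce
from operator import mul


def convnd_dmas(in_channels, out_channels, kernel_size, out_spatial,
                batch=1, groups=1, bias=True):
    """Sketch of the convolution DMA estimate: input reads per output
    element, plus one read per parameter, plus one write per output."""
    out_numel = batch * out_channels * reduce(mul, out_spatial)
    input_dma = (in_channels * reduce(mul, kernel_size) * out_numel) // groups
    weights = out_channels * (in_channels // groups) * reduce(mul, kernel_size)
    ops_dma = weights + (out_channels if bias else 0)
    return input_dma + ops_dma + out_numel


# 3 -> 8 channels, 3x3 kernel, 32x32 output
print(convnd_dmas(3, 8, (3, 3), (32, 32)))  # 221184 + 224 + 8192 = 229600
```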


def dmas_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for `torch.nn.modules.batchnorm._BatchNorm`"""
    input_dma = inp.numel()

    # Access running_mean, running_var and eps
    ops_dma = module.running_mean.numel() + module.running_var.numel() + 1  # type: ignore[union-attr]
    # Access to weight and bias
    if module.affine:
        ops_dma += module.weight.data.numel() + module.bias.data.numel()
    # Exp avg factor
    if module.momentum:
        ops_dma += 1
    # Update stats
    if module.training and module.track_running_stats:
        # Current mean and std computation only requires access to input, already counted in input_dma
        # Update num of batches and running stats
        ops_dma += 1 + module.running_mean.numel() + module.running_var.numel()  # type: ignore[union-attr]

    output_dma = out.numel()

    return input_dma + ops_dma + output_dma
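The batchnorm branch above accumulates parameter and buffer accesses depending on the module configuration. A standalone sketch of just that `ops_dma` term (input/output accesses are counted separately; helper name is hypothetical):

```python
def bn_ops_dmas(num_features, affine=True, momentum=0.1,
                training=False, track_running_stats=True):
    """Sketch of the parameter/buffer accesses counted for BatchNorm."""
    # running_mean, running_var and eps
    ops = 2 * num_features + 1
    if affine:  # weight and bias
        ops += 2 * num_features
    if momentum:  # exponential average factor
        ops += 1
    if training and track_running_stats:
        # num_batches_tracked plus the running-stat updates
        ops += 1 + 2 * num_features
    return ops


print(bn_ops_dmas(64))  # 129 + 128 + 1 = 258
```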


def dmas_pool(module: Union[_MaxPoolNd, _AvgPoolNd], inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for spatial pooling modules"""
    # Resolve the kernel size (can be stored as a single integer or a tuple)
    kernel_size = (
        module.kernel_size
        if isinstance(module.kernel_size, tuple)
        else (module.kernel_size,) * (inp.ndim - 2)
    )

    # Each output element requires prod(kernel_size) memory accesses
    input_dma = reduce(mul, kernel_size) * out.numel()

    output_dma = out.numel()

    return input_dma + output_dma


def dmas_adaptive_pool(_: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor) -> int:
    """DMAs estimation for adaptive spatial pooling modules"""
    # Approximate kernel_size using ratio of spatial shapes between input and output
    kernel_size = tuple(
        i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1
        for i_size, o_size in zip(inp.shape[2:], out.shape[2:], strict=False)
    )
    # Each output element requires prod(kernel_size) memory accesses
    input_dma = reduce(mul, kernel_size) * out.numel()

    output_dma = out.numel()

    return input_dma + output_dma


================================================
FILE: torchscan/modules/receptive.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
import warnings
from typing import Tuple, Union

from torch import Tensor, nn
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd
from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd

__all__ = ["module_rf"]


def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]:
    """Estimate the spatial receptive field of the module

    Args:
        module (torch.nn.Module): PyTorch module
        inp (torch.Tensor): input to the module
        out (torch.Tensor): output of the module
    Returns:
        receptive field
        effective stride
        effective padding
    """
    if isinstance(
        module,
        (
            nn.Identity,
            nn.Flatten,
            nn.ReLU,
            nn.ELU,
            nn.LeakyReLU,
            nn.ReLU6,
            nn.Tanh,
            nn.Sigmoid,
            _BatchNorm,
            nn.Dropout,
            nn.Linear,
        ),
    ):
        return 1.0, 1.0, 0.0
    if isinstance(module, _ConvTransposeNd):
        return rf_convtransposend(module, inp, out)
    if isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)):
        return rf_aggregnd(module, inp, out)
    if isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)):
        return rf_adaptive_poolnd(module, inp, out)
    warnings.warn(f"Module type not supported: {module.__class__.__name__}", stacklevel=1)
    return 1.0, 1.0, 0.0


def rf_convtransposend(module: _ConvTransposeNd, _: Tensor, __: Tensor) -> Tuple[float, float, float]:
    k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
    s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
    return -k, 1.0 / s, 0.0


def rf_aggregnd(module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor) -> Tuple[float, float, float]:
    k = module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size
    if hasattr(module, "dilation"):
        d = module.dilation[0] if isinstance(module.dilation, tuple) else module.dilation
        k = d * (k - 1) + 1
    s = module.stride[0] if isinstance(module.stride, tuple) else module.stride
    p = module.padding[0] if isinstance(module.padding, tuple) else module.padding
    return k, s, p  # type: ignore[return-value]


def rf_adaptive_poolnd(
    _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor
) -> Tuple[int, int, float]:
    stride = math.ceil(inp.shape[-1] / out.shape[-1])
    kernel_size = stride
    padding = (inp.shape[-1] - kernel_size * stride) / 2

    return kernel_size, stride, padding
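`module_rf` returns per-layer stats; chaining them into a network-wide receptive field happens elsewhere (in the crawler). As an illustrative sketch, the standard recurrence `rf += (k - 1) * jump; jump *= s` composes per-layer kernel/stride pairs (the function name is hypothetical):

```python
def compose_rf(layers):
    """Compose per-layer (kernel, stride) pairs into a total receptive
    field, using the standard recurrence over the effective jump."""
    rf, jump = 1.0, 1.0
    for k, s in layers:
        rf += (k - 1) * jump
        jump *= s
    return rf


# Two 3x3 stride-2 convs: 3, then 3 + 2 * 2 = 7
print(compose_rf([(3, 2), (3, 2)]))  # 7.0
```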


================================================
FILE: torchscan/process/__init__.py
================================================
from .memory import *


================================================
FILE: torchscan/process/memory.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import re
import subprocess  # noqa S404
import warnings

import torch

__all__ = ["get_process_gpu_ram"]


def get_process_gpu_ram(pid: int) -> float:
    """Gets the amount of RAM used by a given process on GPU devices

    Args:
        pid: process ID
    Returns:
        RAM usage in Megabytes
    """
    # PyTorch is not responsible for GPU usage
    if not torch.cuda.is_available():
        warnings.warn("CUDA is unavailable to PyTorch.", stacklevel=1)
        return 0.0

    # Query the running processes on GPUs
    try:
        res = subprocess.run(["nvidia-smi", "-q", "-d", "PIDS"], capture_output=True).stdout.decode()
        # Try to locate the process
        pids = re.findall(r"Process ID\s+:\s(\d+)", res)
        for idx, _pid in enumerate(pids):
            if int(_pid) == pid:
                return float(re.findall(r"Used GPU Memory\s+:\s(\d+)", res)[idx])

        # Query total memory used by nvidia
        res = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv"], capture_output=True
        ).stdout.decode()
        return float(res.split("\n")[1].split()[0])
    except FileNotFoundError as e:
        warnings.warn(f"raised: {e}. Parsing NVIDIA-SMI failed.", stacklevel=1)

    # Default to overall RAM usage for this process on the GPU
    ram_str = torch.cuda.list_gpu_processes().split("\n")
    # Take the first process running on the GPU
    if ram_str[1].startswith("process"):
        return float(ram_str[1].split()[3])

    # Otherwise assume the process is running exclusively on CPU
    return 0.0
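The parsing above relies on regexes over `nvidia-smi -q -d PIDS` output. A self-contained sketch against a fabricated excerpt (the sample text is illustrative, not real tool output; the pattern matches digits like the one in the function):

```python
import re

# Hypothetical excerpt of `nvidia-smi -q -d PIDS` output
sample = """
    Processes
        Process ID                        : 1234
            Used GPU Memory               : 1024 MiB
        Process ID                        : 5678
            Used GPU Memory               : 512 MiB
"""

pids = re.findall(r"Process ID\s+:\s(\d+)", sample)
mems = re.findall(r"Used GPU Memory\s+:\s(\d+)", sample)
# Both lists are index-aligned, so pid -> memory is a simple zip
print(dict(zip(pids, mems)))  # {'1234': '1024', '5678': '512'}
```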


================================================
FILE: torchscan/utils.py
================================================
# Copyright (C) 2020-2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from itertools import starmap
from typing import Any, Dict, List, Optional, Tuple


def format_name(name: str, depth: int = 0) -> str:
    """Format a string for nested data printing

    Args:
        name: input string
        depth: depth of the nested information
    Returns:
        formatted string
    """
    if depth == 0:
        return name
    if depth == 1:
        return f"├─{name}"
    return f"{'|    ' * (depth - 1)}└─{name}"
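For reference, its behavior on a couple of depths (the function is reproduced verbatim so the snippet runs standalone):

```python
def format_name(name, depth=0):
    # Reproduced from above for a self-contained illustration
    if depth == 0:
        return name
    if depth == 1:
        return f"├─{name}"
    return f"{'|    ' * (depth - 1)}└─{name}"


print(format_name("features", 1))  # ├─features
print(format_name("conv1", 2))     # |    └─conv1
```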


def wrap_string(s: str, max_len: int, delimiter: str = ".", wrap: str = "[...]", mode: str = "end") -> str:
    """Wrap a string into a given length

    Args:
        s: input string
        max_len: maximum string length
        delimiter: character used for delimiting information categories
        wrap: wrapping sequence used
        mode: wrapping mode
    Returns:
        wrapped string
    """
    if len(s) <= max_len or mode is None:
        return s

    if mode == "end":
        return s[: max_len - len(wrap)] + wrap
    if mode == "mid":
        final_part = s.rpartition(delimiter)[-1]
        wrapped_end = f"{wrap}.{final_part}"
        return s[: max_len - len(wrapped_end)] + wrapped_end
    raise ValueError("received an unexpected value of argument `mode`")
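The two wrapping modes behave quite differently: `"end"` truncates the tail, `"mid"` keeps the last dot-separated component. A self-contained illustration (the function is reproduced verbatim):

```python
def wrap_string(s, max_len, delimiter=".", wrap="[...]", mode="end"):
    # Reproduced from above for a self-contained illustration
    if len(s) <= max_len or mode is None:
        return s
    if mode == "end":
        return s[: max_len - len(wrap)] + wrap
    if mode == "mid":
        final_part = s.rpartition(delimiter)[-1]
        wrapped_end = f"{wrap}.{final_part}"
        return s[: max_len - len(wrapped_end)] + wrapped_end
    raise ValueError("received an unexpected value of argument `mode`")


print(wrap_string("features.denseblock.conv", 15, mode="end"))  # features.d[...]
print(wrap_string("features.denseblock.conv", 15, mode="mid"))  # featu[...].conv
```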


def unit_scale(val: float) -> Tuple[float, str]:
    """Rescale value using scale units

    Args:
        val: input value
    Returns:
        tuple of rescaled value and unit
    """
    if val // 1e12 > 0:
        return val / 1e12, "T"
    if val // 1e9 > 0:
        return val / 1e9, "G"
    if val // 1e6 > 0:
        return val / 1e6, "M"
    if val // 1e3 > 0:
        return val / 1e3, "k"
    return val, ""
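A quick check of the scaling thresholds (the function is reproduced verbatim so the snippet runs standalone):

```python
def unit_scale(val):
    # Reproduced from above for a self-contained illustration
    if val // 1e12 > 0:
        return val / 1e12, "T"
    if val // 1e9 > 0:
        return val / 1e9, "G"
    if val // 1e6 > 0:
        return val / 1e6, "M"
    if val // 1e3 > 0:
        return val / 1e3, "k"
    return val, ""


print(unit_scale(3_000_000))  # (3.0, 'M')
print(unit_scale(500))        # (500, '')
```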


def format_s(f_string: str, min_w: Optional[int] = None, max_w: Optional[int] = None) -> str:
    """Format number strings"""
    if isinstance(min_w, int):
        f_string = f"{f_string:<{min_w}}"
    if isinstance(max_w, int):
        f_string = f"{f_string:.{max_w}}"

    return f_string


def format_line_str(
    layer: Dict[str, Any],
    col_w: Optional[List[int]] = None,
    wrap_mode: str = "mid",
    receptive_field: bool = False,
    effective_rf_stats: bool = False,
) -> List[str]:
    """Wrap all information into multiple lines"""
    if not isinstance(col_w, list):
        col_w = [None] * 7  # type: ignore[list-item]

    max_len = col_w[0] + 3 if isinstance(col_w[0], int) else 100
    line_str = [
        format_s(wrap_string(format_name(layer["name"], layer["depth"]), max_len, mode=wrap_mode), col_w[0], col_w[0]),
        format_s(layer["type"], col_w[1], col_w[1]),
        format_s(str(layer["output_shape"]), col_w[2], col_w[2]),
        format_s(f"{layer['grad_params'] + layer['nograd_params'] + layer['num_buffers']:,}", col_w[3], col_w[3]),
    ]

    if receptive_field:
        line_str.append(format_s(f"{layer['rf']:.0f}", col_w[4], col_w[4]))
        if effective_rf_stats:
            line_str.extend((
                format_s(f"{layer['s']:.0f}", col_w[5], col_w[5]),
                format_s(f"{layer['p']:.0f}", col_w[6], col_w[6]),
            ))

    return line_str


def format_info(
    module_info: Dict[str, Any], wrap_mode: str = "mid", receptive_field: bool = False, effective_rf_stats: bool = False
) -> str:
    """Print module summary for an expected input tensor shape

    Args:
        module_info: dictionary output of `crawl_module`
        wrap_mode: wrapping mode
        receptive_field: whether to display receptive field
        effective_rf_stats: if `receptive_field` is True, displays effective stride and padding
    Returns:
        formatted information
    """
    # Set margin between cols
    margin = 4
    # Dynamic col width
    # Init with headers
    headers = ["Layer", "Type", "Output Shape", "Param #", "Receptive field", "Effective stride", "Effective padding"]
    max_w = [27, 20, 25, 15, 15, 16, 17]
    col_w = [len(s) for s in headers]
    for layer in module_info["layers"]:
        col_w = [
            max(v, len(s))
            for v, s in zip(
                col_w,
                format_line_str(layer, col_w=None, wrap_mode=wrap_mode, receptive_field=True, effective_rf_stats=True),
                strict=False,
            )
        ]

    # Truncate columns that are too long
    col_w = list(starmap(min, zip(col_w, max_w, strict=False)))

    if not receptive_field:
        col_w = col_w[:4]
        headers = headers[:4]
    elif not effective_rf_stats:
        col_w = col_w[:5]
        headers = headers[:5]

    # Define separating lines
    line_length = sum(col_w) + (len(col_w) - 1) * margin
    thin_line = "_" * line_length
    thick_line = "=" * line_length
    dot_line = "-" * line_length

    margin_str = " " * margin

    # Header
    info_str = [
        thin_line,
        margin_str.join([f"{col_name:<{col_w}}" for col_name, col_w in zip(headers, col_w, strict=False)]),
        thick_line,
    ]

    # Layers
    for layer in module_info["layers"]:
        line_str = format_line_str(layer, col_w, wrap_mode, receptive_field, effective_rf_stats)
        info_str.append((" " * margin).join(line_str))

    # Parameter information
    num_params = module_info["overall"]["grad_params"] + module_info["overall"]["nograd_params"]
    info_str.extend((
        thick_line,
        f"Trainable params: {module_info['overall']['grad_params']:,}",
        f"Non-trainable params: {module_info['overall']['nograd_params']:,}",
        f"Total params: {num_params:,}",
    ))

    # Static RAM usage
    info_str.append(dot_line)

    # Convert to Megabytes
    param_size = (module_info["overall"]["param_size"] + module_info["overall"]["buffer_size"]) / 1024**2
    overhead = module_info["overheads"]["framework"]["fwd"] + module_info["overheads"]["cuda"]["fwd"]

    info_str.extend((
        f"Model size (params + buffers): {param_size:.2f} Mb",
        f"Framework & CUDA overhead: {overhead:.2f} Mb",
        f"Total RAM usage: {param_size + overhead:.2f} Mb",
    ))

    # FLOPS information
    info_str.append(dot_line)

    flops, flops_units = unit_scale(sum(layer["flops"] for layer in module_info["layers"]))
    macs, macs_units = unit_scale(sum(layer["macs"] for layer in module_info["layers"]))
    dmas, dmas_units = unit_scale(sum(layer["dmas"] for layer in module_info["layers"]))

    info_str.extend((
        f"Floating Point Operations on forward: {flops:.2f} {flops_units}FLOPs",
        f"Multiply-Accumulations on forward: {macs:.2f} {macs_units}MACs",
        f"Direct memory accesses on forward: {dmas:.2f} {dmas_units}DMAs",
        thin_line,
    ))

    return "\n".join(info_str)


def aggregate_info(info: Dict[str, Any], max_depth: int) -> Dict[str, Any]:
    """Aggregate module information to a maximum depth

    Args:
        info: dictionary output of `crawl_module`
        max_depth: depth at which parent node aggregates children information
    Returns:
        edited dictionary information
    """
    if not any(layer["depth"] == max_depth for layer in info["layers"]):
        raise ValueError("The `max_depth` argument cannot be higher than module depth.")

    for fw_idx, layer in enumerate(info["layers"]):
        # Need to aggregate information
        if not layer["is_leaf"] and layer["depth"] == max_depth:
            grad_p, nograd_p, p_size, num_buffers, b_size = 0, 0, 0, 0, 0
            flops, macs, dmas = 0, 0, 0
            for _layer in info["layers"][fw_idx + 1 :]:
                # Children have superior depth and were hooked after parent
                if _layer["depth"] <= max_depth:
                    break
                # Aggregate all information (flops, macc, ram)
                flops += _layer["flops"]
                macs += _layer["macs"]
                dmas += _layer["dmas"]
                grad_p += _layer["grad_params"]
                nograd_p += _layer["nograd_params"]
                p_size += _layer["param_size"]
                num_buffers += _layer["num_buffers"]
                b_size += _layer["buffer_size"]
                # Take last child effective RF
                _rf, _s, _p = _layer["rf"], _layer["s"], _layer["p"]

            # Update info
            info["layers"][fw_idx]["flops"] = flops
            info["layers"][fw_idx]["macs"] = macs
            info["layers"][fw_idx]["dmas"] = dmas
            info["layers"][fw_idx]["rf"] = _rf
            info["layers"][fw_idx]["s"] = _s
            info["layers"][fw_idx]["p"] = _p
            info["layers"][fw_idx]["grad_params"] = grad_p
            info["layers"][fw_idx]["nograd_params"] = nograd_p
            info["layers"][fw_idx]["param_size"] = p_size
            info["layers"][fw_idx]["num_buffers"] = num_buffers
            info["layers"][fw_idx]["buffer_size"] = b_size

    # Filter out further depth information
    info["layers"] = [layer for layer in info["layers"] if layer["depth"] <= max_depth]

    return info
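The aggregation walks the layer list in forward-hook order: a non-leaf layer at `max_depth` absorbs the stats of the deeper layers recorded right after it, then deeper layers are filtered out. A simplified sketch limited to a single metric (toy data, hypothetical helper; the real function also aggregates params, buffers and receptive-field stats):

```python
def aggregate_flops(layers, max_depth):
    """Simplified sketch of the depth aggregation above, flops only."""
    for idx, layer in enumerate(layers):
        if not layer["is_leaf"] and layer["depth"] == max_depth:
            total = 0
            for child in layers[idx + 1:]:
                # Children have greater depth and were hooked right after
                if child["depth"] <= max_depth:
                    break
                total += child["flops"]
            layer["flops"] = total
    # Drop layers deeper than the aggregation depth
    return [layer for layer in layers if layer["depth"] <= max_depth]


layers = [
    {"depth": 1, "is_leaf": False, "flops": 0},
    {"depth": 2, "is_leaf": True, "flops": 10},
    {"depth": 2, "is_leaf": True, "flops": 5},
    {"depth": 1, "is_leaf": True, "flops": 3},
]
print([layer["flops"] for layer in aggregate_flops(layers, 1)])  # [15, 3]
```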
SYMBOL INDEX (106 symbols across 16 files)

FILE: .github/collect_env.py
  class SystemEnv (line 40) | class SystemEnv(NamedTuple):
  function run (line 52) | def run(command):
  function run_and_read_all (line 64) | def run_and_read_all(run_lambda, command):
  function run_and_parse_first_match (line 72) | def run_and_parse_first_match(run_lambda, command, regex):
  function get_nvidia_driver_version (line 83) | def get_nvidia_driver_version(run_lambda):
  function get_gpu_info (line 91) | def get_gpu_info(run_lambda):
  function get_running_cuda_version (line 105) | def get_running_cuda_version(run_lambda):
  function get_cudnn_version (line 109) | def get_cudnn_version(run_lambda):
  function get_nvidia_smi (line 143) | def get_nvidia_smi():
  function get_platform (line 159) | def get_platform():
  function get_mac_version (line 171) | def get_mac_version(run_lambda):
  function get_windows_version (line 175) | def get_windows_version(run_lambda):
  function get_lsb_version (line 179) | def get_lsb_version(run_lambda):
  function check_release_file (line 183) | def check_release_file(run_lambda):
  function get_os (line 187) | def get_os(run_lambda):
  function get_env_info (line 216) | def get_env_info():
  function pretty_str (line 255) | def pretty_str(envinfo):
  function get_pretty_env_info (line 303) | def get_pretty_env_info():
  function main (line 312) | def main():

FILE: .github/verify_labels.py
  function query_repo (line 47) | def query_repo(cmd: str, *, accept) -> Any:
  function get_pr_merger_and_labels (line 54) | def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]:
  function main (line 62) | def main(args):
  function parse_args (line 69) | def parse_args():

FILE: docs/source/_static/js/custom.js
  function addGithubButton (line 15) | function addGithubButton() {
  function addVersionControl (line 29) | function addVersionControl() {
  function parseGithubButtons (line 95) | function parseGithubButtons (){"use strict";var e=window.document,t=e.lo...
  function onLoad (line 97) | function onLoad() {

FILE: docs/source/conf.py
  function add_ga_javascript (line 111) | def add_ga_javascript(app, pagename, templatename, context, doctree):
  function setup (line 126) | def setup(app):

FILE: scripts/benchmark.py
  function main (line 54) | def main():

FILE: tests/test_crawler.py
  function test_apply (line 11) | def test_apply():
  function test_crawl_module (line 26) | def test_crawl_module():
  function test_summary (line 35) | def test_summary():

FILE: tests/test_modules.py
  class MyModule (line 8) | class MyModule(nn.Module):
    method __init__ (line 9) | def __init__(self):
  function test_module_flops_warning (line 13) | def test_module_flops_warning():
  function test_module_flops (line 53) | def test_module_flops(mod, input_shape, output_shape, expected_val):
  function test_transformer_flops (line 57) | def test_transformer_flops():
  function test_module_macs_warning (line 64) | def test_module_macs_warning():
  function test_module_macs (line 95) | def test_module_macs(mod, input_shape, output_shape, expected_val):
  function test_module_dmas_warning (line 99) | def test_module_dmas_warning():
  function test_module_dmas (line 134) | def test_module_dmas(mod, input_shape, output_shape, expected_val):

FILE: tests/test_process.py
  function test_get_process_gpu_ram (line 8) | def test_get_process_gpu_ram():

FILE: tests/test_utils.py
  function test_format_name (line 6) | def test_format_name():
  function test_wrap_string (line 13) | def test_wrap_string():
  function test_unit_scale (line 35) | def test_unit_scale(input_val, num_val, unit):

FILE: torchscan/crawler.py
  function apply (line 19) | def apply(module: Module, fn: Callable[[Module, str], None], name: Optio...
  function crawl_module (line 34) | def crawl_module(
  function summary (line 268) | def summary(

FILE: torchscan/modules/flops.py
  function module_flops (line 21) | def module_flops(module: Module, inputs: Tuple[Tensor, ...], out: Tensor...
  function flops_linear (line 69) | def flops_linear(module: nn.Linear, inputs: Tuple[Tensor, ...]) -> int:
  function flops_sigmoid (line 79) | def flops_sigmoid(_: nn.Sigmoid, inputs: Tuple[Tensor, ...]) -> int:
  function flops_relu (line 85) | def flops_relu(_: nn.ReLU, inputs: Tuple[Tensor, ...]) -> int:
  function flops_elu (line 91) | def flops_elu(_: nn.ELU, inputs: Tuple[Tensor, ...]) -> int:
  function flops_leakyrelu (line 97) | def flops_leakyrelu(_: nn.LeakyReLU, inputs: Tuple[Tensor, ...]) -> int:
  function flops_relu6 (line 103) | def flops_relu6(_: nn.ReLU6, inputs: Tuple[Tensor, ...]) -> int:
  function flops_tanh (line 109) | def flops_tanh(_: nn.Tanh, inputs: Tuple[Tensor, ...]) -> int:
  function flops_dropout (line 115) | def flops_dropout(module: nn.Dropout, inputs: Tuple[Tensor, ...]) -> int:
  function flops_convtransposend (line 123) | def flops_convtransposend(module: _ConvTransposeNd, inputs: Tuple[Tensor...
  function flops_convnd (line 135) | def flops_convnd(module: _ConvNd, inputs: Tuple[Tensor, ...], out: Tenso...
  function flops_bn (line 151) | def flops_bn(module: _BatchNorm, inputs: Tuple[Tensor, ...]) -> int:
  function flops_maxpool (line 178) | def flops_maxpool(module: _MaxPoolNd, _: Tuple[Tensor, ...], out: Tensor...
  function flops_avgpool (line 186) | def flops_avgpool(module: _AvgPoolNd, inputs: Tuple[Tensor, ...], out: T...
  function flops_adaptive_maxpool (line 194) | def flops_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inputs: Tuple[Tensor, ...
  function flops_adaptive_avgpool (line 206) | def flops_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inputs: Tuple[Tensor, ...
  function flops_layernorm (line 218) | def flops_layernorm(module: nn.LayerNorm, inputs: Tuple[Tensor, ...]) ->...
  function flops_mha (line 234) | def flops_mha(module: nn.MultiheadAttention, inputs: Tuple[Tensor, ...])...
  function flops_transformer_encoderlayer (line 287) | def flops_transformer_encoderlayer(module: nn.TransformerEncoderLayer, i...
  function flops_transformer_decoderlayer (line 304) | def flops_transformer_decoderlayer(module: nn.TransformerDecoderLayer, i...
  function flops_transformer (line 326) | def flops_transformer(module: nn.Transformer, inputs: Tuple[Tensor, ...]...

FILE: torchscan/modules/macs.py
  function module_macs (line 19) | def module_macs(module: Module, inp: Tensor, out: Tensor) -> int:
  function macs_linear (line 53) | def macs_linear(module: nn.Linear, _: Tensor, out: Tensor) -> int:
  function macs_convtransposend (line 59) | def macs_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Ten...
  function macs_convnd (line 71) | def macs_convnd(module: _ConvNd, inp: Tensor, out: Tensor) -> int:
  function macs_bn (line 84) | def macs_bn(module: _BatchNorm, inp: Tensor, _: Tensor) -> int:
  function macs_maxpool (line 111) | def macs_maxpool(module: _MaxPoolNd, _: Tensor, out: Tensor) -> int:
  function macs_avgpool (line 119) | def macs_avgpool(module: _AvgPoolNd, inp: Tensor, out: Tensor) -> int:
  function macs_adaptive_maxpool (line 127) | def macs_adaptive_maxpool(_: _AdaptiveMaxPoolNd, inp: Tensor, out: Tenso...
  function macs_adaptive_avgpool (line 139) | def macs_adaptive_avgpool(_: _AdaptiveAvgPoolNd, inp: Tensor, out: Tenso...

FILE: torchscan/modules/memory.py
  function module_dmas (line 20) | def module_dmas(module: Module, inp: Tensor, out: Tensor) -> int:
  function num_params (line 61) | def num_params(module: Module) -> int:
  function dmas_identity (line 72) | def dmas_identity(_: nn.Identity, inp: Tensor, __: Tensor) -> int:
  function dmas_flatten (line 77) | def dmas_flatten(_: nn.Flatten, inp: Tensor, __: Tensor) -> int:
  function dmas_linear (line 82) | def dmas_linear(module: nn.Linear, inp: Tensor, out: Tensor) -> int:
  function dmas_relu (line 92) | def dmas_relu(module: Union[nn.ReLU, nn.ReLU6], inp: Tensor, out: Tensor...
  function dmas_act_single_param (line 100) | def dmas_act_single_param(module: Union[nn.ELU, nn.LeakyReLU], inp: Tens...
  function dmas_sigmoid (line 110) | def dmas_sigmoid(_: nn.Sigmoid, inp: Tensor, out: Tensor) -> int:
  function dmas_tanh (line 119) | def dmas_tanh(_: nn.Tanh, inp: Tensor, out: Tensor) -> int:
  function dmas_dropout (line 128) | def dmas_dropout(module: nn.Dropout, inp: Tensor, out: Tensor) -> int:
  function dmas_convtransposend (line 140) | def dmas_convtransposend(module: _ConvTransposeNd, inp: Tensor, out: Ten...
  function dmas_convnd (line 153) | def dmas_convnd(module: _ConvNd, _: Tensor, out: Tensor) -> int:
  function dmas_bn (line 167) | def dmas_bn(module: _BatchNorm, inp: Tensor, out: Tensor) -> int:
  function dmas_pool (line 190) | def dmas_pool(module: Union[_MaxPoolNd, _AvgPoolNd], inp: Tensor, out: T...
  function dmas_adaptive_pool (line 206) | def dmas_adaptive_pool(_: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd],...
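The `dmas_*` helpers count direct memory accesses (reads and writes). As a hedged illustration of the general idea — not the library's exact accounting — a linear layer can be approximated as reading its input and parameters once and writing its output once:

```python
# Illustrative approximation of direct memory accesses (DMAs) for a linear
# layer. The bookkeeping in torchscan/modules/memory.py may count differently;
# this only conveys the reads-plus-writes structure of such an estimate.

def linear_dmas(in_features: int, out_features: int,
                num_rows: int = 1, bias: bool = True) -> int:
    reads = num_rows * in_features        # input activations
    reads += in_features * out_features   # weight matrix
    if bias:
        reads += out_features             # bias vector
    writes = num_rows * out_features      # output activations
    return reads + writes

print(linear_dmas(4, 3, num_rows=2))  # 29
```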

FILE: torchscan/modules/receptive.py
  function module_rf (line 19) | def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, ...
  function rf_convtransposend (line 58) | def rf_convtransposend(module: _ConvTransposeNd, _: Tensor, __: Tensor) ...
  function rf_aggregnd (line 64) | def rf_aggregnd(module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tenso...
  function rf_adaptive_poolnd (line 74) | def rf_adaptive_poolnd(
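The receptive-field helpers above aggregate per-layer kernel sizes and strides. The classic recurrence — reproduced here as a textbook illustration, not necessarily torchscan's exact implementation — grows the field by `(k - 1)` times the accumulated stride at each layer:

```python
# Hedged sketch of the standard receptive-field recurrence for a chain of
# conv/pool layers, given input-to-output order. This mirrors the textbook
# formula, not the code in torchscan/modules/receptive.py.

def receptive_field(layers):
    """layers: sequence of (kernel_size, stride) pairs.

    Returns the receptive field of one output unit, in input pixels.
    """
    rf, jump = 1, 1  # one input pixel; unit effective stride
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k-1)*jump
        jump *= stride             # effective stride compounds multiplicatively
    return rf

# Two stacked 3x3 stride-1 convs see the same 5x5 window as one 5x5 conv:
print(receptive_field([(3, 1), (3, 1)]))  # 5
```

Note that layers after a stride-2 downsampling widen the field twice as fast, which is why `jump` is tracked alongside `rf`.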

FILE: torchscan/process/memory.py
  function get_process_gpu_ram (line 15) | def get_process_gpu_ram(pid: int) -> float:
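`get_process_gpu_ram(pid)` reports GPU memory for a given process. Assuming it relies on an `nvidia-smi`-style per-process CSV query (an assumption — the file's source is truncated here), the parsing step might look like the following sketch, which mimics rather than reproduces the library:

```python
# Hypothetical helper: parse nvidia-smi-style CSV output
# (e.g. from `nvidia-smi --query-compute-apps=pid,used_memory --format=csv`)
# and return the MiB used by a given process. Illustrative only.

def parse_gpu_ram(smi_output: str, pid: int) -> float:
    """Return MiB of GPU RAM used by `pid`, or 0.0 if the pid is absent."""
    for line in smi_output.strip().splitlines()[1:]:  # skip CSV header
        row_pid, used = (field.strip() for field in line.split(","))
        if int(row_pid) == pid:
            return float(used.split()[0])  # "842 MiB" -> 842.0
    return 0.0

sample = "pid, used_gpu_memory [MiB]\n1337, 842 MiB\n2001, 128 MiB\n"
print(parse_gpu_ram(sample, 1337))  # 842.0
```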

FILE: torchscan/utils.py
  function format_name (line 10) | def format_name(name: str, depth: int = 0) -> str:
  function wrap_string (line 26) | def wrap_string(s: str, max_len: int, delimiter: str = ".", wrap: str = ...
  function unit_scale (line 50) | def unit_scale(val: float) -> Tuple[float, str]:
  function format_s (line 69) | def format_s(f_string: str, min_w: Optional[int] = None, max_w: Optional...
  function format_line_str (line 79) | def format_line_str(
  function format_info (line 109) | def format_info(
  function aggregate_info (line 208) | def aggregate_info(info: Dict[str, Any], max_depth: int) -> Dict[str, Any]:
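`unit_scale(val)` returns a `(float, str)` pair, which from its signature suggests scaling a raw count into a value plus an SI-style suffix for display. A minimal reimplementation of that idea — an assumption from the name and return type, not the library's actual code — could be:

```python
# Hypothetical sketch of what a unit_scale helper might do: scale a raw
# count into (value, suffix) for human-readable display. The thresholds
# and suffix set are assumptions, not torchscan/utils.py.

def unit_scale(val: float):
    """E.g. 1_500_000 -> (1.5, 'M'); values under 1000 pass through."""
    for threshold, suffix in ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "k")):
        if val >= threshold:
            return val / threshold, suffix
    return val, ""

print(unit_scale(1_500_000))  # (1.5, 'M')
```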

Condensed preview — 57 files, each showing path, character count, and a content snippet; the full structured content runs to 183K characters.
[
  {
    "path": ".conda/meta.yaml",
    "chars": 1161,
    "preview": "{% set pyproject = load_file_data('../pyproject.toml', from_recipe_dir=True) %}\n{% set project = pyproject.get('project'"
  },
  {
    "path": ".github/FUNDING.yml",
    "chars": 631,
    "preview": "# These are supported funding model platforms\n\ngithub: frgfm\npatreon: # Replace with a single Patreon username\nopen_coll"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "chars": 1932,
    "preview": "name: 🐛 Bug report\ndescription: Create a report to help us improve the library\nlabels: 'type: bug'\nassignees: frgfm\n\nbod"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "chars": 200,
    "preview": "blank_issues_enabled: true\ncontact_links:\n  - name: Usage questions\n    url: https://github.com/frgfm/torch-scan/discuss"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "chars": 1037,
    "preview": "name: 🚀 Feature request\ndescription: Submit a proposal/request for a new feature\nlabels: 'type: enhancement'\nassignees: "
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "chars": 1107,
    "preview": "# What does this PR do?\n\n<!--\nWell, hello there! Thank you for proposing modifications to the project.\n\nMake sure to hav"
  },
  {
    "path": ".github/collect_env.py",
    "chars": 9800,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": ".github/dependabot.yml",
    "chars": 957,
    "preview": "# To get started with Dependabot version updates, you'll need to specify which\n# package ecosystems to update and where "
  },
  {
    "path": ".github/labeler.yml",
    "chars": 932,
    "preview": "'module: crawler':\n- changed-files:\n  - any-glob-to-any-file: torchscan/crawler.py\n\n'module: modules':\n- changed-files:\n"
  },
  {
    "path": ".github/release.yml",
    "chars": 480,
    "preview": "changelog:\n  exclude:\n    labels:\n      - ignore-for-release\n  categories:\n    - title: Breaking Changes 🛠\n      labels:"
  },
  {
    "path": ".github/verify_labels.py",
    "chars": 2525,
    "preview": "# Copyright (C) 2022-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": ".github/workflows/builds.yml",
    "chars": 1899,
    "preview": "name: builds\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  build:\n    runs-on: ${{ matrix."
  },
  {
    "path": ".github/workflows/doc-status.yml",
    "chars": 656,
    "preview": "name: GH-Pages Status\non:\n  page_build\n\njobs:\n  see-page-build-payload:\n    runs-on: ubuntu-latest\n    steps:\n      - us"
  },
  {
    "path": ".github/workflows/docs.yml",
    "chars": 1118,
    "preview": "name: docs\non:\n  push:\n    branches: main\n\njobs:\n  docs-deploy:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix"
  },
  {
    "path": ".github/workflows/pr-labels.yml",
    "chars": 1120,
    "preview": "name: pr-labels\n\non:\n  pull_request:\n    branches: main\n    types: closed\n\njobs:\n  is-properly-labeled:\n    if: github.e"
  },
  {
    "path": ".github/workflows/publish.yml",
    "chars": 2629,
    "preview": "name: publish\n\non:\n  release:\n    types: [published]\n\njobs:\n  pypi:\n    if: \"!github.event.release.prerelease\"\n    runs-"
  },
  {
    "path": ".github/workflows/pull_requests.yml",
    "chars": 772,
    "preview": "name: pull_requests\n\non:\n  pull_request:\n    branches: main\n\njobs:\n  docs:\n    runs-on: ubuntu-latest\n    steps:\n      -"
  },
  {
    "path": ".github/workflows/style.yml",
    "chars": 2094,
    "preview": "name: style\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  ruff:\n    runs-on: ${{ matrix.os"
  },
  {
    "path": ".github/workflows/tests.yml",
    "chars": 1675,
    "preview": "name: tests\n\non:\n  push:\n    branches: main\n  pull_request:\n    branches: main\n\njobs:\n  pytest:\n    runs-on: ${{ matrix."
  },
  {
    "path": ".gitignore",
    "chars": 1884,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 793,
    "preview": "default_language_version:\n    python: python3.11\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 5228,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 4312,
    "preview": "# Contributing to torchscan\n\nEverything you need to know to contribute efficiently to the project.\n\nWhatever the way you"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "Makefile",
    "chars": 440,
    "preview": "# this target runs checks on all files\nquality:\n\truff format --check .\n\truff check .\n\tmypy\n\n# this target runs checks on"
  },
  {
    "path": "README.md",
    "chars": 11835,
    "preview": "<p align=\"center\">\n  <img src=\"https://github.com/frgfm/torch-scan/releases/download/v0.1.1/logo_text.png\" width=\"30%\">\n"
  },
  {
    "path": "docs/Makefile",
    "chars": 638,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
  },
  {
    "path": "docs/README.md",
    "chars": 1100,
    "preview": "# Changing the documentation\n\nThe documentation of this project is built using `sphinx`. In order to install all the bui"
  },
  {
    "path": "docs/build.sh",
    "chars": 1219,
    "preview": "function deploy_doc(){\n    if [ ! -z \"$1\" ]\n    then\n        git checkout $1\n    fi\n    COMMIT=$(git rev-parse --short H"
  },
  {
    "path": "docs/make.bat",
    "chars": 799,
    "preview": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sp"
  },
  {
    "path": "docs/source/_static/css/custom.css",
    "chars": 682,
    "preview": "h1 {\n    font-size: 200%;\n}\n\n/* Github button */\n\n.github-repo {\n    display: flex;\n    justify-content: center;\n}\n\n/* V"
  },
  {
    "path": "docs/source/_static/js/custom.js",
    "chars": 13806,
    "preview": "// Based on https://github.com/huggingface/transformers/blob/master/docs/source/_static/js/custom.js\n\n\n// These two thin"
  },
  {
    "path": "docs/source/changelog.rst",
    "chars": 390,
    "preview": "Changelog\n=========\n\n\nv0.1.2 (2022-08-03)\n-------------------\nRelease note: `v0.1.2 <https://github.com/frgfm/torch-scan"
  },
  {
    "path": "docs/source/conf.py",
    "chars": 5024,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "docs/source/index.rst",
    "chars": 4166,
    "preview": "**************************************\nTorchScan: inspect your PyTorch models\n**************************************\n\nTh"
  },
  {
    "path": "docs/source/installing.rst",
    "chars": 650,
    "preview": "\n************\nInstallation\n************\n\nThis library requires `Python <https://www.python.org/downloads/>`_ 3.6 or high"
  },
  {
    "path": "docs/source/modules.rst",
    "chars": 633,
    "preview": "torchscan.modules\n=================\n\nThe modules subpackage contains tools for inspection of modules.\n\n.. currentmodule:"
  },
  {
    "path": "docs/source/process.rst",
    "chars": 259,
    "preview": "torchscan.process\n=================\n\nThe process subpackage contains tools regarding active Python processes.\n\nThe follo"
  },
  {
    "path": "docs/source/torchscan.rst",
    "chars": 127,
    "preview": "torchscan\n=========\n\n\n.. currentmodule:: torchscan\n\n\nCrawler\n~~~~~~~\n\n.. autofunction:: crawl_module\n.. autofunction:: s"
  },
  {
    "path": "docs/source/utils.rst",
    "chars": 133,
    "preview": "torchscan.utils\n===============\n\n.. currentmodule:: torchscan.utils\n\n.. autofunction:: format_info\n\n.. autofunction:: ag"
  },
  {
    "path": "pyproject.toml",
    "chars": 5403,
    "preview": "[build-system]\nrequires = [\"setuptools\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"torchscan\"\n"
  },
  {
    "path": "scripts/benchmark.py",
    "chars": 2293,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "setup.py",
    "chars": 697,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "tests/test_crawler.py",
    "chars": 2559,
    "preview": "import io\nimport sys\nfrom collections import OrderedDict\n\nimport pytest\nimport torch.nn as nn\n\nfrom torchscan import cra"
  },
  {
    "path": "tests/test_modules.py",
    "chars": 7713,
    "preview": "import pytest\nimport torch\nfrom torch import nn\n\nfrom torchscan import modules\n\n\nclass MyModule(nn.Module):\n    def __in"
  },
  {
    "path": "tests/test_process.py",
    "chars": 255,
    "preview": "import os\n\nimport torch\n\nfrom torchscan import process\n\n\ndef test_get_process_gpu_ram():\n    if torch.cuda.is_initialize"
  },
  {
    "path": "tests/test_utils.py",
    "chars": 1066,
    "preview": "import pytest\n\nfrom torchscan import utils\n\n\ndef test_format_name():\n    name = \"mymodule\"\n    assert utils.format_name("
  },
  {
    "path": "torchscan/__init__.py",
    "chars": 176,
    "preview": "from contextlib import suppress\nfrom torchscan import modules, process, utils\nfrom torchscan.crawler import *\n\nwith supp"
  },
  {
    "path": "torchscan/crawler.py",
    "chars": 11995,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/modules/__init__.py",
    "chars": 88,
    "preview": "from .flops import *\nfrom .macs import *\nfrom .memory import *\nfrom .receptive import *\n"
  },
  {
    "path": "torchscan/modules/flops.py",
    "chars": 14748,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/modules/macs.py",
    "chars": 6380,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/modules/memory.py",
    "chars": 7641,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/modules/receptive.py",
    "chars": 2951,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/process/__init__.py",
    "chars": 22,
    "preview": "from .memory import *\n"
  },
  {
    "path": "torchscan/process/memory.py",
    "chars": 1799,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  },
  {
    "path": "torchscan/utils.py",
    "chars": 9082,
    "preview": "# Copyright (C) 2020-2024, François-Guillaume Fernandez.\n\n# This program is licensed under the Apache License 2.0.\n# See"
  }
]

About this extraction

This page contains the full source code of the frgfm/torch-scan GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 57 files (169.0 KB, approximately 49.3k tokens) and a symbol index of 106 extracted functions, classes, methods, constants, and types. It can be fed to OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.