[
  {
    "path": ".circleci/config.yml",
    "content": "version: 2.1\naliases:\n  - &container_python\n    docker:\n      - image: cimg/python:3.8.4  # primary container for the build job\n\n  - &run_task_install_tox_dependencies\n    run:\n      name: install tox dependencies\n      command: |\n        sudo apt-get update\n        sudo apt install -y build-essential libssl-dev libpython-dev python python-pip\n        sudo -H pip install --upgrade pip tox virtualenv\n\norbs:\n  codecov: codecov/codecov@1.0.4\njobs:\n  testing:\n    parameters:\n      tests:\n        type: string\n        default: py38-torch10,py38-torch11,py38-torch14,py38-torch17\n    <<: *container_python\n    steps:\n      - checkout\n      - *run_task_install_tox_dependencies\n      - run:\n          name: execute pytests << parameters.tests >>\n          no_output_timeout: 30m\n          command: |\n            mkdir test-reports\n            tox -e << parameters.tests >>\n      - codecov/upload:\n          flags: backend,unittest\n      - store_artifacts:\n          path: htmlcov\n      - store_test_results:\n          path: test-reports\n      - codecov/upload:\n          file: coverage/*.json\n          flags: frontend\n  builddocs:\n    <<: *container_python\n    steps:\n      - checkout\n      - *run_task_install_tox_dependencies\n      - run:\n          name: build the sphinx documentation\n          command: |\n            tox -e docs\n  conda_deploy:\n    parameters:\n      versions:\n        type: string\n        default: \"3.7 3.8\"\n    docker:\n      - image: continuumio/miniconda3\n    steps:\n      - checkout\n      - run:\n          name: install conda dependencies\n          command: |\n            conda install conda-build anaconda-client conda-verify -y\n      - run:\n          name: generate skeleton file from PyPI and complete recipe\n          command: |\n            cd ~\n            conda skeleton pypi memcnn\n            cd memcnn\n            python -c \"f = open('meta.yaml', 'r'); data = f.read(); f.close(); data=data.replace(' torch ', ' pytorch ').replace('your-github-id-here', 'silvandeleemput').replace('pillow\\n    - python', 'pillow\\n    - pip\\n    - python').replace(' pip', ' pip >=18.0'); f = open('meta.yaml', 'w'); f.write(data); f.close();\"\n            cat ~/memcnn/meta.yaml\n      - run:\n          name: build binary artifacts for python versions << parameters.versions >>\n          no_output_timeout: 30m\n          command: |\n            cd ~/memcnn\n            PYTHON_VERSIONS=( << parameters.versions >> )\n            for i in \"${PYTHON_VERSIONS[@]}\"\n            do\n                echo $i\n                conda-build -c conda-forge -c simpleitk -c pytorch --numpy 1.15.1 --python $i .\n            done\n      - run:\n          name: upload binary artifacts for all platforms to anaconda cloud\n          command: |\n            anaconda login --user=silvandeleemput --password=$CONDA_PASSWORD\n            find /opt/conda/conda-bld/ -name *.tar.bz2 | while read file\n            do\n                echo $file\n                anaconda upload $file --skip-existing --all\n            done\n  deploy:\n    docker:\n      - image: cimg/python:3.8.4\n    steps:\n      - checkout\n      - restore_cache:\n          key: v1-dependency-cache-{{ checksum \"setup.py\" }}\n      - run:\n          name: install python dependencies\n          command: |\n            python3 -m venv venv\n            . venv/bin/activate\n            pip install --upgrade pip\n            pip install pylint doc8 coverage codecov twine\n            pip install -e .\n      - save_cache:\n          key: v1-dependency-cache-{{ checksum \"setup.py\" }}\n          paths:\n            - \"venv\"\n      - run:\n          name: verify git tag vs. version\n          command: |\n            python3 -m venv venv\n            . venv/bin/activate\n            python setup.py verify\n      - run:\n          name: init .pypirc\n          command: |\n            echo -e \"[pypi]\" >> ~/.pypirc\n            echo -e \"username = Sil\" >> ~/.pypirc\n            echo -e \"password = $PYPI_PASSWORD\" >> ~/.pypirc\n      - run:\n          name: createpackages\n          command: |\n            python setup.py sdist\n            python setup.py bdist_wheel\n      - run:\n          name: upload to pypi\n          command: |\n            . venv/bin/activate\n            twine upload dist/*\n      - run:\n          name: trigger docker hub master branch build\n          command: |\n            curl -H \"Content-Type: application/json\" --data '{\"source_type\": \"Branch\", \"source_name\": \"master\"}' -X POST $DOCKER_TRIGGER_URL\n      - run:\n          name: trigger docker hub latest tag build\n          command: |\n            curl -H \"Content-Type: application/json\" --data '{\"source_type\": \"Tag\", \"source_name\": \"'\"$CIRCLE_TAG\"'\"}' -X POST $DOCKER_TRIGGER_URL\n\nworkflows:\n  version: 2\n  build_test_and_deploy:\n    jobs:\n      - testing:\n          name: testing_py38_torch14\n          tests: py38-torch14\n          filters:\n            tags:\n              only: /.*/\n      - testing:\n          name: testing_py38_torch17\n          tests: py38-torch17\n          filters:\n            tags:\n              only: /.*/\n      - deploy:\n          requires:\n            - testing_py38_torch14\n            - testing_py38_torch17\n          filters:\n            tags:\n              only: /[0-9]+(\\.[0-9]+)*/\n            branches:\n              ignore: /.*/\n      - conda_deploy:\n          name: conda_deploy_py37\n          requires:\n            - deploy\n          versions: \"3.7\"\n          filters:\n            tags:\n              only: /[0-9]+(\\.[0-9]+)*/\n            branches:\n              ignore: /.*/\n      - conda_deploy:\n          name: conda_deploy_py38\n          requires:\n            - deploy\n          versions: \"3.8\"\n          filters:\n            tags:\n              only: /[0-9]+(\\.[0-9]+)*/\n            branches:\n              ignore: /.*/\n"
  },
  {
    "path": ".coveragerc",
    "content": "[run]\nomit =\n  env/*\n  venv/*\n  tests/*\n  setup.py\n  */tests/*.py\n\nsource =\n  .\n"
  },
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to making participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or advances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic address, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a professional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies within all project spaces, and it also applies when\nan individual is representing the project or its community in public spaces.\nExamples of representing a project or community include using an official\nproject e-mail address, posting via an official social media account, or acting\nas an appointed representative at an online or offline event. Representation of\na project may be further defined and clarified by project maintainers.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at silvandeleemput@gmail.com. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at <https://www.contributor-covenant.org/version/1/4/code-of-conduct.html>\n\nFor answers to common questions about this code of conduct, see\n<https://www.contributor-covenant.org/faq>\n\n[homepage]: https://www.contributor-covenant.org"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# Contributing\n\nContributions are welcome, and they are greatly appreciated! \nEvery little bit helps, and credit will always be given.\n\nThe latest information about how to contribute to MemCNN can be found here:\n<https://memcnn.readthedocs.io/en/latest/contributing.html>\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE.md",
    "content": "* MemCNN version:\n* PyTorch version:\n* Python version:\n* Operating System:\n\n### Description\n\nDescribe what you were trying to get done.\nTell us what happened, what went wrong, and what you expected to happen.\n\n### What I Did\n\n```\nPaste the command(s) you ran and the output.\nIf there was a crash, please include the traceback here.\n```\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "<!-- Thank you for your contribution to MemCNN! Please replace {Please write here} with your description -->\n\n### What was the problem?\n\n{Please write here}\n\n### How this PR fixes the problem?\n\n{Please write here}\n\n### Check lists (check `x` in `[ ]` of list items)\n\n- [ ] Test passed\n- [ ] Coding style (indentation, etc)\n\n### Additional Comments (if any)\n\n{Please write here}"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# remove memcnn config\n/memcnn/config/config.json\n\n# PyCharm configs\n/.idea\n"
  },
  {
    "path": ".readthedocs.yml",
    "content": "# .readthedocs.yml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/conf.py\n\n# Build documentation with MkDocs\n#mkdocs:\n#  configuration: mkdocs.yml\n\n# Optionally build your docs in additional formats such as PDF and ePub\nformats:\n  - htmlzip\n  - epub\n\n# Optionally set the version of Python and requirements required to build your docs\npython:\n  version: 3.7\n  install:\n    - requirements: docsRequirements.txt"
  },
  {
    "path": "AUTHORS.rst",
    "content": "=======\nCredits\n=======\n\nDevelopment Lead\n----------------\n\n* Sil van de Leemput <silvandeleemput@gmail.com>\n\nContributors\n------------\n\n* Tycho van der Ouderaa\n* Jonas Teuwen\n* Bram van Ginneken\n* Rashindra Manniesing\n"
  },
  {
    "path": "CONTRIBUTING.rst",
    "content": ".. highlight:: shell\n\n============\nContributing\n============\n\nContributions are welcome, and they are greatly appreciated! Every little bit\nhelps, and credit will always be given.\n\nYou can contribute in many ways:\n\nTypes of Contributions\n----------------------\n\nReport Bugs\n~~~~~~~~~~~\n\nReport bugs at https://github.com/silvandeleemput/memcnn/issues.\n\nIf you are reporting a bug, please include:\n\n* Your operating system name and version.\n* Any details about your local setup that might be helpful in troubleshooting.\n* Detailed steps to reproduce the bug.\n\nFix Bugs\n~~~~~~~~\n\nLook through the GitHub issues for bugs. Anything tagged with \"bug\" and \"help\nwanted\" is open to whoever wants to implement it.\n\nImplement Features\n~~~~~~~~~~~~~~~~~~\n\nLook through the GitHub issues for features. Anything tagged with \"enhancement\"\nand \"help wanted\" is open to whoever wants to implement it.\n\nWrite Documentation\n~~~~~~~~~~~~~~~~~~~\n\nMemCNN could always use more documentation, whether as part of the\nofficial MemCNN docs, in docstrings, or even on the web in blog posts,\narticles, and such.\n\nSubmit Feedback\n~~~~~~~~~~~~~~~\n\nThe best way to send feedback is to file an issue at https://github.com/silvandeleemput/memcnn/issues.\n\nIf you are proposing a feature:\n\n* Explain in detail how it would work.\n* Keep the scope as narrow as possible, to make it easier to implement.\n* Remember that this is a volunteer-driven project, and that contributions\n  are welcome :)\n\nGet Started!\n------------\n\nReady to contribute? Here's how to set up `memcnn` for local development.\n\n1. Fork the `memcnn` repo on GitHub.\n2. Clone your fork locally::\n\n    $ git clone git@github.com:your_name_here/memcnn.git\n\n3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::\n\n    $ mkvirtualenv memcnn\n    $ cd memcnn/\n    $ python setup.py develop\n\n4. Create a branch for local development::\n\n    $ git checkout -b name-of-your-bugfix-or-feature\n\n   Now you can make your changes locally.\n\n5. When you're done making changes, check that your changes pass flake8 and the\n   tests, including testing other Python versions with tox::\n\n    $ flake8 memcnn tests\n    $ python setup.py test or py.test\n    $ tox\n\n   To get flake8 and tox, just pip install them into your virtualenv.\n\n6. Commit your changes and push your branch to GitHub::\n\n    $ git add .\n    $ git commit -m \"Your detailed description of your changes.\"\n    $ git push origin name-of-your-bugfix-or-feature\n\n7. Submit a pull request through the GitHub website.\n\nPull Request Guidelines\n-----------------------\n\nBefore you submit a pull request, check that it meets these guidelines:\n\n1. The pull request should include tests.\n2. If the pull request adds functionality, the docs should be updated. Put\n   your new functionality into a function with a docstring, and add the\n   feature to the list in README.rst.\n3. The pull request should work for Python 2.7, 3.5+, and for PyPy. Check\n   through tox that all the tests pass for all supported Python versions.\n\nTips\n----\n\nTo run a subset of tests::\n\n$ pytest memcnn/memcnn/models/tests/test_revop.py\n\nTo run a specific test::\n\n$ pytest memcnn/memcnn/models/tests/test_revop.py::test_reversible_block_fwd_bwd\n\n\nDeploying\n---------\n\nA reminder for the maintainers on how to deploy.\nMake sure all your changes are committed (including an entry in HISTORY.rst).\nThen run::\n\n$ bumpversion patch # possible: major / minor / patch\n$ git push\n$ git push origin <tag_name>\n\nCircleCI will then deploy to PyPI if tests pass.\n"
  },
  {
    "path": "HISTORY.rst",
    "content": "=======\nHistory\n=======\n\n1.5.2 (2023-05-10)\n------------------\n* Fixed issue with CIFAR data loaders not being able to be pickled because of local Lambda operations\n* Fixed CI issues, disabled PyTorch v1.0, v1.1, and latest checks\n\n1.5.1 (2021-08-07)\n------------------\n* Added support for 2-dimensional inputs for AffineAdapterSigmoid\n* Fixed CI issues\n\n1.5.0 (2020-11-24)\n------------------\n* Added support for mixed-precision training using torch.cuda.amp (inputs fixed to float32 for now)\n* Added support for PyTorch v1.7\n* Dropped support for PyTorch < v1.0 and Python 2\n* Removed the version limit for Pillow in the requirements\n\n1.4.0 (2020-06-05)\n------------------\n* Added support for splitting on arbitrary dimensions to the Couplings. Big thanks to ClashLuke for the PR\n* Added a preserve_rng_state option to the InvertibleModuleWrapper\n\n1.3.2 (2020-03-05)\n------------------\n* Improved InvertibleModuleWrapper\n  * Added support for multi input/output invertible operations! Big thanks to Christian Etmann for the PR\n* Improved the is_invertible_module test\n  * Added multi input/output checks\n  * Fixed random seed per default\n  * Additional warning checks have been added\n\n1.3.1 (2020-03-02)\n------------------\n* HOTFIX InvertibleCheckpointFunction uses ref_count for inputs as well to avoid memory spikes\n\n1.3.0 (2020-03-01)\n------------------\n\n* Updated underlying mechanics for the InvertibleModuleWrapper\n  * Hooks have been replaced by a torch.autograd.Function called InvertibleCheckpointFunction\n  * Identity functions are now supported\n* Reported unstable memory behavior should be fixed now when using the InvertibleModuleWrapper!\n* Minor changes to test suite\n\n1.2.1 (2020-02-24)\n------------------\n\n* Added InvertibleModuleWrapper support to is_invertible_module test\n\n1.2.0 (2020-01-19)\n------------------\n\n* Replaced TensorBoard logging with simple json file logging which removed the cumbersome TensorBoard and TensorFlow dependencies\n* Updated the Dockerfile for Python37 and PyTorch 1.4.0\n* Updated the CI tests Py36 versions to Py37, also added a new CI test for PyTorch 1.4.0\n\n1.1.1 (2020-01-11)\n------------------\n\n* Fixed some versions in the requirements for TensorFlow and Pillow to avoid errors and segfaults\n* The module auto documentation has been updated for the new API changes\n\n1.1.0 (2019-12-15)\n------------------\n\n* A complete refactor of MemCNN with changes to the API\n* Factored out the code responsible for the memory savings in a separate InvertibleModuleWrapper and reimplemented it using hooks\n* The InvertibleModuleWrapper allows for arbitrary invertible functions now (not just the additive and affine couplings)\n* The AdditiveBlock and AffineBlock have been refactored to AdditiveCoupling and AffineCoupling\n* The ReveribleBlock is now deprecated\n* The documentation and examples have been updated for the new API changes\n\n1.0.1 (2019-12-08)\n------------------\n\n* Bug fixes related to SummaryIterator import in Tensorflow 2\n  (location of summary_iterator has changed in TensorFlow)\n* Bug fixes related to NSamplesRandomSampler nsamples attribute\n  (would crash if no-gpu and numpy.int were given)\n\n\n1.0.0 (2019-07-28)\n------------------\n\n* Major release for completing the JOSS review:\n* Anaconda cloud and codacy code quality CI\n* Updated/improved documentation\n\n0.3.5 (2019-07-28)\n------------------\n\n* Added CI for anaconda cloud\n* Documented conda installation steps\n* Minor test release for testing CI build\n\n0.3.4 (2019-07-26)\n------------------\n\n* Performed changes recommended by JOSS reviewers:\n* Added requirements.txt to manifest.in\n* Added codacy code quality integration\n* Improved documentation\n* Setup proper github contribution templates\n\n0.3.3 (2019-07-10)\n------------------\n\n* Added docker build triggers to CI\n* Finalized JOSS paper.md\n\n0.3.2 (2019-07-10)\n------------------\n\n* Added docker build shield\n* Fixed a bug with device agnostic tensor generation for loss.py\n* Code cleanup resnet.py\n* Added examples to distribution with pytests\n* Improved documentation\n\n0.3.1 (2019-07-09)\n------------------\n\n* Added experiments.json and config.json.example data files to the distribution\n* Fixed documentation issues with mock modules\n\n0.3.0 (2019-07-09)\n------------------\n\n* Updated major bug in distribution setup.py\n* Removed older releases due to bug\n* Added the ReversibleBlock at the module level\n* Splitted keep_input into keep_input and keep_input_inverse\n\n0.2.1 (2019-06-06 - Removed)\n----------------------------\n\n* Patched the memory saving tests\n\n0.2.0 (2019-05-28 - Removed)\n----------------------------\n\n* Minor update with better coverage and affine coupling support\n\n0.1.0 (2019-05-24 - Removed)\n----------------------------\n\n* First release on PyPI\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2018 Sil C. van de Leemput\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include AUTHORS.rst\ninclude CONTRIBUTING.rst\ninclude HISTORY.rst\ninclude LICENSE.txt\ninclude README.rst\ninclude requirements.txt\ninclude devRequirements.txt\ninclude docsRequirements.txt\n\nrecursive-include tests *\nrecursive-exclude * __pycache__\nrecursive-exclude * *.py[co]\n\nrecursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif\nrecursive-include memcnn/config config.json.example experiments.json\n"
  },
  {
    "path": "README.rst",
    "content": "======\nMemCNN\n======\n\n.. image:: https://img.shields.io/badge/maintenance-unmaintained-red.svg\n        :alt: Unmaintained!\n        :target: https://github.com/silvandeleemput/memcnn\n\n.. image:: https://img.shields.io/circleci/build/github/silvandeleemput/memcnn/master.svg        \n        :alt: CircleCI - Status master branch\n        :target: https://circleci.com/gh/silvandeleemput/memcnn/tree/master\n\n.. image:: https://readthedocs.org/projects/memcnn/badge/?version=latest        \n        :alt: Documentation - Status master branch\n        :target: https://memcnn.readthedocs.io/en/latest/?badge=latest\n\n.. image:: https://img.shields.io/codacy/grade/95de32e0d7c54d038611da47e9f0948b/master.svg\n        :alt: Codacy - Branch grade\n        :target: https://app.codacy.com/project/silvandeleemput/memcnn/dashboardgit\n\n.. image:: https://img.shields.io/codecov/c/gh/silvandeleemput/memcnn/master.svg   \n        :alt: Codecov - Status master branch\n        :target: https://codecov.io/gh/silvandeleemput/memcnn\n\n.. image:: https://img.shields.io/pypi/v/memcnn.svg\n        :alt: PyPI - Latest release\n        :target: https://pypi.python.org/pypi/memcnn\n\n.. image:: https://img.shields.io/conda/vn/silvandeleemput/memcnn?label=anaconda\n        :alt: Conda - Latest release\n        :target: https://anaconda.org/silvandeleemput/memcnn\n\n.. image:: https://img.shields.io/pypi/implementation/memcnn.svg        \n        :alt: PyPI - Implementation\n        :target: https://pypi.python.org/pypi/memcnn\n\n.. image:: https://img.shields.io/pypi/pyversions/memcnn.svg        \n        :alt: PyPI - Python version\n        :target: https://pypi.python.org/pypi/memcnn\n\n.. image:: https://img.shields.io/github/license/silvandeleemput/memcnn.svg        \n        :alt: GitHub - Repository license\n        :target: https://github.com/silvandeleemput/memcnn/blob/master/LICENSE.txt\n\n.. 
image:: http://joss.theoj.org/papers/10.21105/joss.01576/status.svg\n        :alt: JOSS - DOI\n        :target: https://doi.org/10.21105/joss.01576\n\nA `PyTorch <http://pytorch.org/>`__ framework for developing memory-efficient invertible neural networks.\n\n* Free software: `MIT license <https://github.com/silvandeleemput/memcnn/blob/master/LICENSE.txt>`__ (please cite our work if you use it)\n* Documentation: https://memcnn.readthedocs.io.\n* Installation: https://memcnn.readthedocs.io/en/latest/installation.html\n\n⚠️ Project Status: Unmaintained\n\nThis repository is no longer actively maintained.\n\nThe code is kept available for reference and historical purposes, but no new features, bug fixes, or support should be expected.\n\nIf you find the project useful, feel free to fork it and continue development.\n\nFeatures\n--------\n\n* Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the `InvertibleModuleWrapper` class.\n* Simple toggling of memory saving by setting the `keep_input` property of the `InvertibleModuleWrapper`.\n* Turn arbitrary non-linear PyTorch functions into invertible versions using the `AdditiveCoupling` or the `AffineCoupling` classes.\n* Training and evaluation code for reproducing RevNet experiments using MemCNN.\n* CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage.\n\nExamples\n--------\n\nCreating an AdditiveCoupling with memory savings\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. 
code:: python\n\n    import torch\n    import torch.nn as nn\n    import memcnn\n\n\n    # define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d\n    class ExampleOperation(nn.Module):\n        def __init__(self, channels):\n            super(ExampleOperation, self).__init__()\n            self.seq = nn.Sequential(\n                                        nn.Conv2d(in_channels=channels, out_channels=channels,\n                                                  kernel_size=(3, 3), padding=1),\n                                        nn.BatchNorm2d(num_features=channels),\n                                        nn.ReLU(inplace=True)\n                                    )\n\n        def forward(self, x):\n            return self.seq(x)\n\n\n    # generate some random input data (batch_size, num_channels, y_elements, x_elements)\n    X = torch.rand(2, 10, 8, 8)\n\n    # application of the operation(s) the normal way\n    model_normal = ExampleOperation(channels=10)\n    model_normal.eval()\n\n    Y = model_normal(X)\n\n    # turn the ExampleOperation invertible using an additive coupling\n    invertible_module = memcnn.AdditiveCoupling(\n        Fm=ExampleOperation(channels=10 // 2),\n        Gm=ExampleOperation(channels=10 // 2)\n    )\n\n    # test that it is actually a valid invertible module (has a valid inverse method)\n    assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)\n\n    # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training\n    invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)\n\n    # by default the module is set to training, the following sets this to evaluation\n    # note that this is required to pass input tensors to the model with requires_grad=False (inference only)\n    invertible_module_wrapper.eval()\n\n    # test that the wrapped module is also a valid 
invertible module\n    assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)\n\n    # compute the forward pass using the wrapper\n    Y2 = invertible_module_wrapper.forward(X)\n\n    # the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2\n    X2 = invertible_module_wrapper.inverse(Y2)\n\n    # test that the input and approximation are similar\n    assert torch.allclose(X, X2, atol=1e-06)\n\nRun PyTorch Experiments\n-----------------------\n\nAfter installing MemCNN run:\n\n.. code:: bash\n\n    python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]\n\n* Available values for ``DATASET`` are ``cifar10`` and ``cifar100``.\n* Available values for ``MODEL`` are ``resnet32``, ``resnet110``, ``resnet164``, ``revnet38``, ``revnet110``, ``revnet164``\n* Use the ``--fresh`` flag to remove earlier experiment results.\n* Use the ``--no-cuda`` flag to train on the CPU rather than the GPU through CUDA.\n\nDatasets are automatically downloaded if they are not available.\n\nWhen using Python 3.* replace the ``python`` directive with the appropriate Python 3 directive. For example when using the MemCNN docker image use ``python3.6``.\n\nWhen MemCNN was installed using `pip` or from sources you might need to setup a configuration file before running this command.\nRead the corresponding section about how to do this here: https://memcnn.readthedocs.io/en/latest/installation.html\n\nResults\n-------\n\nTensorFlow results were obtained from `the reversible residual\nnetwork <https://arxiv.org/abs/1707.04585>`__ running the code from\ntheir `GitHub <https://github.com/renmengye/revnet-public>`__.\n\nThe PyTorch results listed were recomputed on June 11th 2018, and differ\nfrom the results in the ICLR paper. 
The Tensorflow results are still the\nsame.\n\nPrediction accuracy\n^^^^^^^^^^^^^^^^^^^\n\n+------------+------------------------+--------------------------+----------------------+----------------------+\n|            |               Cifar-10                            |               Cifar-100                     |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| Model      |    Tensorflow          |      PyTorch             |      Tensorflow      |     PyTorch          |\n+============+========================+==========================+======================+======================+\n| resnet-32  |  92.74                 |    92.86                 |   69.10              |  69.81               |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| resnet-110 |  93.99                 |    93.55                 |   73.30              |  72.40               |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| resnet-164 |  94.57                 |    94.80                 |   76.79              |  76.47               |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-38  |  93.14                 |    92.80                 |   71.17              |  69.90               |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-110 |  94.02                 |    94.10                 |   74.00              |  73.30               |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-164 |  94.56                 |    94.90                 |   76.39              |  76.90               
|\n+------------+------------------------+--------------------------+----------------------+----------------------+\n\nTraining time (hours : minutes)\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+------------+------------------------+--------------------------+----------------------+----------------------+\n|            |               Cifar-10                            |               Cifar-100                     |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| Model      |    Tensorflow          |      PyTorch             |      Tensorflow      |     PyTorch          |\n+============+========================+==========================+======================+======================+\n| resnet-32  |             2:04       |    1:51                  |       1:58           |              1:51    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| resnet-110 |             4:11       |    2:51                  |       6:44           |              2:39    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| resnet-164 |            11:05       |    4:59                  |   10:59              |              3:45    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-38  |             2:17       |    2:09                  |       2:20           |              2:16    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-110 |             6:59       |    3:42                  |       7:03           |              3:50    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n| revnet-164 |            13:09       |    7:21                  |   13:12              |              
7:17    |\n+------------+------------------------+--------------------------+----------------------+----------------------+\n\nMemory consumption of model training in PyTorch\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n+------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+\n|               Layers                              |               Parameters                    |               Parameters (MB)                     |               Activations (MB)              |\n+------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+\n|    ResNet              |      RevNet              |    ResNet            |      RevNet          |    ResNet              |      RevNet              |    ResNet            |      RevNet          |\n+========================+==========================+======================+======================+========================+==========================+======================+======================+\n|               32       |    38                    |       466906         |          573994      |             1.9        |    2.3                   |       238.6          |              85.6    |\n+------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+\n|              110       |    110                   |       1730714        |           1854890    |             6.8        |    7.3                   |       810.7          |              85.7    
|\n+------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+\n|              164       |    164                   |   1704154            |         1983786      |            6.8         |    7.9                   |   2452.8             |             432.7    |\n+------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+\n\nThe `ResNet` model is the conventional Residual Network implementation in PyTorch, while\nthe RevNet model uses the `memcnn.InvertibleModuleWrapper` to achieve memory savings.\n\nWorks using MemCNN\n------------------\n\n* `MemCNN: a Framework for Developing Memory Efficient Deep Invertible Networks <https://openreview.net/forum?id=r1KzqK1wz>`__ by Sil C. van de Leemput et al.\n* `Reversible GANs for Memory-efficient Image-to-Image Translation <https://arxiv.org/abs/1902.02729>`__ by Tycho van der Ouderaa et al.\n* `Chest CT Super-resolution and Domain-adaptation using Memory-efficient 3D Reversible GANs <https://openreview.net/forum?id=SkxueFsiFV>`__ by Tycho van der Ouderaa et al.\n* `iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling <https://arxiv.org/abs/2005.05220>`__ by Christian Etmann et al.\n\nCitation\n--------\n\nSil C. van de Leemput, Jonas Teuwen, Bram van Ginneken, and Rashindra Manniesing.\nMemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks.\nJournal of Open Source Software, 4, 1576, http://dx.doi.org/10.21105/joss.01576, 2019.\n\nIf you use our code, please cite:\n\n.. 
code:: bibtex\n\n    @article{vandeLeemput2019MemCNN,\n      journal = {Journal of Open Source Software},\n      doi = {10.21105/joss.01576},\n      issn = {2475-9066},\n      number = {39},\n      publisher = {The Open Journal},\n      title = {MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks},\n      url = {http://dx.doi.org/10.21105/joss.01576},\n      volume = {4},\n      author = {Sil C. {van de} Leemput and Jonas Teuwen and Bram {van} Ginneken and Rashindra Manniesing},\n      pages = {1576},\n      date = {2019-07-30},\n      year = {2019},\n      month = {7},\n      day = {30},\n    }\n"
  },
  {
    "path": "bandit.yml",
    "content": "skips: ['B101']\n"
  },
  {
    "path": "devRequirements.txt",
    "content": "-r requirements.txt\nbumpversion\nwheel\nwatchdog\nflake8\ntox\ncoverage\nSphinx\ntwine\npytest\npytest-cov\npytest-runner\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04\n\nRUN apt-get update && apt-get install -y \\\n  software-properties-common \\\n  && \\\n  rm -rf /var/lib/apt/lists/*\nRUN add-apt-repository ppa:deadsnakes/ppa && apt-get update\nRUN apt-get install -y \\\n  git \\\n  python3.7-dev \\\n  python3-pip \\\n  sudo \\\n  && rm -rf /var/lib/apt/lists/*\n\n# Add user with valid password\nRUN useradd -ms /bin/bash user\nRUN (echo user ; echo user) | passwd user\n\n# Configure sudo\nRUN usermod -a -G sudo user\n\n# Install necessary python libraries\nRUN python3.7 -m pip install pip --upgrade\nRUN python3.7 -m pip install torch===1.7.0 torchvision===0.8.1 torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html\nRUN python3.7 -m pip install memcnn\nRUN python3.7 -m pip install pytest\n\n# Set MemCNN config file for user environment\nRUN python3.7 -c \"import os, shutil, memcnn; path=os.path.join(os.path.dirname(memcnn.__file__), 'config'); shutil.copy(os.path.join(path, 'config.json.example'), os.path.join(path, 'config.json'));\"\n\n# Change user and prepare user data folders\nUSER user\nWORKDIR /home/user\nRUN mkdir data\nRUN mkdir experiments\n\nENTRYPOINT /bin/bash\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = python -msphinx\nSPHINXPROJ    = memcnn\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/authors.rst",
    "content": ".. include:: ../AUTHORS.rst\n"
  },
  {
    "path": "docs/conf.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# memcnn documentation build configuration file, created by\n# sphinx-quickstart on Fri Jun  9 13:47:02 2017.\n#\n# This file is execfile()d with the current directory set to its\n# containing dir.\n#\n# Note that not all possible configuration values are present in this\n# autogenerated file.\n#\n# All configuration values have a default; values that are commented out\n# serve to show the default.\n\n# If extensions (or modules to document with autodoc) are in another\n# directory, add these directories to sys.path here. If the directory is\n# relative to the documentation root, use os.path.abspath to make it\n# absolute, like shown here.\n#\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath('..'))\n\n# this gets the memcnn_version without importing memcnn and causing troubles with mock later on\nwith open(os.path.join(os.path.dirname(__file__), '..', 'memcnn', '__init__.py'), 'r') as f:\n    memcnn_version = [line.split(\"'\")[1] for line in f.readlines() if '__version__' in line][0]\n\n# -- General configuration ---------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.\nextensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.mathjax',\n              'sphinx.ext.napoleon', 'sphinx.ext.intersphinx']\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = '.rst'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# Napoleon settings\nnapoleon_google_docstring = False\nnapoleon_numpy_docstring = True\n\n# autodoc settings\nautoclass_content = 'both'\nautodoc_mock_imports = ['torch', 'torch.nn', 'numpy', 'torchvision']\n\n\nintersphinx_mapping = {\n    'python': ('https://docs.python.org/', None),\n    'numpy': ('http://docs.scipy.org/doc/numpy/', None),\n    'torch': ('https://pytorch.org/docs/stable/', None)\n}\n\nmathjax_path = \"https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML\"\nmathjax_config = {\n    'extensions': ['tex2jax.js'],\n    'jax': ['input/TeX', 'output/HTML-CSS'],\n}\n\n# General information about the project.\nproject = u'MemCNN'\ncopyright = u\"2019, Sil van de Leemput\"\nauthor = u\"Sil van de Leemput\"\n\n\n# The version info for the project you're documenting, acts as replacement\n# for |version| and |release|, also used in various other places throughout\n# the built documents.\n#\n# The short X.Y version.\nversion = memcnn_version\n# The full version, including alpha/beta/rc tags.\nrelease = memcnn_version\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = None\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This patterns also effect to html_static_path and html_extra_path\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = False\n\n\n# -- Options for HTML output -------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\n# Theme options are theme-specific and customize the look and feel of a\n# theme further.  For a list of options available for each theme, see the\n# documentation.\n#\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = []\n\n\n# -- Options for HTMLHelp output ---------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'memcnndoc'\n\n\n# -- Options for LaTeX output ------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title, author, documentclass\n# [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'memcnn.tex',\n     u'MemCNN Documentation',\n     u'Sil van de Leemput', 'manual'),\n]\n\n\n# -- Options for manual page output ------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    (master_doc, 'memcnn',\n     u'MemCNN Documentation',\n     [author], 1)\n]\n\n\n# -- Options for Texinfo output ----------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'memcnn',\n     u'MemCNN Documentation',\n     author,\n     'memcnn',\n     'A PyTorch framework for developing memory efficient deep invertible networks.',\n     'Miscellaneous'),\n]\n"
  },
  {
    "path": "docs/contributing.rst",
    "content": ".. include:: ../CONTRIBUTING.rst\n"
  },
  {
    "path": "docs/history.rst",
    "content": ".. include:: ../HISTORY.rst\n"
  },
  {
    "path": "docs/index.rst",
    "content": "Welcome to MemCNN's documentation!\n======================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Contents:\n\n   readme\n   installation\n   usage\n   modules\n   contributing\n   authors\n   history\n\nIndices and tables\n==================\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/installation.rst",
    "content": ".. highlight:: shell\n\n============\nInstallation\n============\n\nRequirements\n------------\n\n-  `Python <https://python.org/>`__ 3.6+\n-  `PyTorch <http://pytorch.org/>`__ 1.0+ (CUDA support recommended)\n\n\nStable release\n--------------\n\nThese are the preferred methods to install MemCNN, as they will always install the most recent stable release.\n\nPyPi\n^^^^\n\nTo install MemCNN using the Python package manager, run this command in your terminal:\n\n.. code-block:: console\n\n    $ pip install memcnn\n\nIf you don't have `pip`_ installed, this `Python installation guide`_ can guide\nyou through the process.\n\n.. _pip: https://pip.pypa.io\n.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/\n\nAnaconda\n^^^^^^^^\n\nTo install MemCNN using Anaconda, run this command in your terminal:\n\n.. code-block:: console\n\n    $ conda install -c silvandeleemput -c pytorch -c simpleitk -c conda-forge memcnn\n\nIf you don't have `conda`_ installed, this `Anaconda installation guide`_ can guide\nyou through the process.\n\n.. _conda: https://www.anaconda.com/\n.. _Anaconda installation guide: https://docs.conda.io/projects/conda/en/latest/user-guide/install/\n\nFrom sources\n------------\n\nThe sources for MemCNN can be downloaded from the `Github repo`_.\n\nYou can either clone the public repository:\n\n.. code-block:: console\n\n    $ git clone git://github.com/silvandeleemput/memcnn\n\nOr download the `tarball`_:\n\n.. code-block:: console\n\n    $ curl  -OL https://github.com/silvandeleemput/memcnn/tarball/master\n\nOnce you have a copy of the source, you can install it with:\n\n.. code-block:: console\n\n    $ python setup.py install\n\n\n.. _Github repo: https://github.com/silvandeleemput/memcnn\n.. 
_tarball: https://github.com/silvandeleemput/memcnn/tarball/master\n\n\nUsing docker\n------------\n\nMemCNN has several pre-built docker images that are hosted on dockerhub.\nYou can directly pull these to have a working environment for running the experiments.\n\nRun image from repository\n^^^^^^^^^^^^^^^^^^^^^^^^^\n\nRun the latest docker build of MemCNN from the repository (automatically pulls the image):\n\n.. code-block:: console\n\n    $ docker run --shm-size=4g --runtime=nvidia -it silvandeleemput/memcnn:latest\n\nFor ``--runtime=nvidia`` to work `nvidia-docker <https://github.com/nvidia/nvidia-docker>`__ must be installed on your system.\nIt can be omitted but this will drop GPU training support.\n\nThis will open a preconfigured bash shell, which is correctly configured\nto run the experiments. The latest version has Ubuntu 18.04 and Python 3.7 installed.\n\nBy default, the datasets and experimental results will be put inside the created\ndocker container under: ``/home/user/data`` and\n``/home/user/experiments`` respectively.\n\nBuild image from source\n^^^^^^^^^^^^^^^^^^^^^^^\n\nRequirements:\n\n-  NVIDIA graphics card and the proper NVIDIA-drivers on your system\n\n\nThe following bash commands will clone this repository and do a one-time\nbuild of the docker image with the right environment installed:\n\n.. 
code-block:: console\n\n    $ git clone https://github.com/silvandeleemput/memcnn.git\n    $ docker build ./memcnn/docker --tag=silvandeleemput/memcnn:latest\n\nAfter the one-time install on your machine, the docker image can be invoked\nusing the same commands as listed above.\n\nExperiment configuration file\n-----------------------------\n\nTo run the experiments, MemCNN requires setting up a configuration file containing locations to put the data files.\nThis step is not necessary for the docker builds.\n\n\nThe configuration file ``config.json`` goes in the ``/memcnn/config/`` directory of the library and should be formatted as follows:\n\n.. code:: json\n\n    {\n        \"data_dir\": \"/home/user/data\",\n        \"results_dir\": \"/home/user/experiments\"\n    }\n\n* data_dir    : location for storing the input training datasets\n* results_dir : location for storing the experiment files during training\n\nChange the data paths to your liking.\n\nIf you are unsure where MemCNN and/or the configuration file is located on your machine run:\n\n.. code-block:: console\n\n    $ python -m memcnn.train\n\nIf the configuration file is not setup correctly, this command should give the user the correct path to the configuration file.\nNext, create/edit the file at the given location.\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=python -msphinx\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\nset SPHINXPROJ=memcnn\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The Sphinx module was not found. Make sure you have Sphinx installed,\n\techo.then set the SPHINXBUILD environment variable to point to the full\n\techo.path of the 'sphinx-build' executable. Alternatively you may add the\n\techo.Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "docs/modules.rst",
    "content": "=======\nModules\n=======\n\n.. automodule:: memcnn\n  :members: is_invertible_module\n\n.. autoclass:: memcnn.InvertibleModuleWrapper\n  :members: forward, inverse\n\n.. autoclass:: memcnn.AdditiveCoupling\n  :members: forward, inverse\n\n.. autoclass:: memcnn.AffineCoupling\n  :members: forward, inverse\n\n.. autoclass:: memcnn.AffineAdapterNaive\n\n.. autoclass:: memcnn.AffineAdapterSigmoid\n\n.. autoclass:: memcnn.ReversibleBlock\n  :members: forward, inverse\n"
  },
  {
    "path": "docs/readme.rst",
    "content": ".. include:: ../README.rst\n"
  },
  {
    "path": "docs/usage.rst",
    "content": "=====\nUsage\n=====\n\nTo use MemCNN in a project::\n\n    import memcnn\n\n\nExamples\n--------\n\nCreating an AdditiveCoupling with memory savings\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. code:: python\n\n    import torch\n    import torch.nn as nn\n    import memcnn\n\n\n    # define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d\n    class ExampleOperation(nn.Module):\n        def __init__(self, channels):\n            super(ExampleOperation, self).__init__()\n            self.seq = nn.Sequential(\n                                        nn.Conv2d(in_channels=channels, out_channels=channels,\n                                                  kernel_size=(3, 3), padding=1),\n                                        nn.BatchNorm2d(num_features=channels),\n                                        nn.ReLU(inplace=True)\n                                    )\n\n        def forward(self, x):\n            return self.seq(x)\n\n\n    # generate some random input data (batch_size, num_channels, y_elements, x_elements)\n    X = torch.rand(2, 10, 8, 8)\n\n    # application of the operation(s) the normal way\n    model_normal = ExampleOperation(channels=10)\n    model_normal.eval()\n\n    Y = model_normal(X)\n\n    # turn the ExampleOperation invertible using an additive coupling\n    invertible_module = memcnn.AdditiveCoupling(\n        Fm=ExampleOperation(channels=10 // 2),\n        Gm=ExampleOperation(channels=10 // 2)\n    )\n\n    # test that it is actually a valid invertible module (has a valid inverse method)\n    assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)\n\n    # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training\n    invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)\n\n    # by default the module is set to training, the following sets this 
to evaluation\n    # note that this is required to pass input tensors to the model with requires_grad=False (inference only)\n    invertible_module_wrapper.eval()\n\n    # test that the wrapped module is also a valid invertible module\n    assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)\n\n    # compute the forward pass using the wrapper\n    Y2 = invertible_module_wrapper.forward(X)\n\n    # the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2\n    X2 = invertible_module_wrapper.inverse(Y2)\n\n    # test that the input and approximation are similar\n    assert torch.allclose(X, X2, atol=1e-06)\n\nRun PyTorch Experiments\n-----------------------\n\n.. include:: ./usage_experiments.rst\n"
  },
  {
    "path": "docs/usage_experiments.rst",
    "content": "After installing MemCNN run:\n\n.. code:: bash\n\n    python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda]\n\n* Available values for ``DATASET`` are ``cifar10`` and ``cifar100``.\n* Available values for ``MODEL`` are ``resnet32``, ``resnet110``, ``resnet164``, ``revnet38``, ``revnet110``, ``revnet164``\n* Use the ``--fresh`` flag to remove earlier experiment results.\n* Use the ``--no-cuda`` flag to train on the CPU rather than the GPU through CUDA.\n\nDatasets are automatically downloaded if they are not available.\n\nWhen using Python 3.* replace the ``python`` directive with the appropriate Python 3 directive. For example when using the MemCNN docker image use ``python3.7``.\n"
  },
  {
    "path": "docsRequirements.txt",
    "content": "sphinx\nsphinxcontrib-plantuml\nsphinxcontrib-ansibleautodoc\nsphinx_rtd_theme\nPyYAML\nmock\n"
  },
  {
    "path": "memcnn/.editorconfig",
    "content": "# http://editorconfig.org\n\nroot = true\n\n[*]\nindent_style = space\nindent_size = 4\ntrim_trailing_whitespace = true\ninsert_final_newline = true\ncharset = utf-8\nend_of_line = lf\n\n[*.bat]\nindent_style = tab\nend_of_line = crlf\n\n[LICENSE]\ninsert_final_newline = false\n\n[Makefile]\nindent_style = tab\n"
  },
  {
    "path": "memcnn/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Top-level package for MemCNN.\"\"\"\n\n__author__ = \"\"\"Sil van de Leemput\"\"\"\n__email__ = 'silvandeleemput@gmail.com'\n__version__ = '1.5.2'\n\n\nfrom memcnn.models.revop import ReversibleBlock, InvertibleModuleWrapper, create_coupling, is_invertible_module\nfrom memcnn.models.additive import AdditiveCoupling\nfrom memcnn.models.affine import AffineCoupling, AffineAdapterNaive, AffineAdapterSigmoid\n\n__all__ = [\n    'AdditiveCoupling',\n    'AffineCoupling',\n    'AffineAdapterNaive',\n    'AffineAdapterSigmoid',\n    'InvertibleModuleWrapper',\n    'ReversibleBlock',\n    'create_coupling',\n    'is_invertible_module'\n]\n"
  },
  {
    "path": "memcnn/config/__init__.py",
    "content": "import json\nimport os\n\n\nclass Config(dict):\n    def __init__(self, dic=None, verbose=False):\n        super(Config, self).__init__()\n        if dic is None:\n            fname = self.get_filename()\n            if verbose:\n                print(\"loading default {0}\".format(fname))\n            with open(fname, \"r\") as f:\n                dic = json.load(f)\n        self.update(dic)\n\n    @staticmethod\n    def get_filename():\n        return os.path.join(Config.get_dir(), \"config.json\")\n\n    @staticmethod\n    def get_dir():\n        return os.path.dirname(__file__)\n"
  },
  {
    "path": "memcnn/config/config.json.example",
    "content": "{\n  \"data_dir\": \"/home/user/data\",\n  \"results_dir\": \"/home/user/experiments\"\n}\n"
  },
  {
    "path": "memcnn/config/experiments.json",
    "content": "{\n    \"resnet32\":\n    {\n        \"data_loader_params\": {\n            \"batch_size\": 100,\n            \"max_epoch\": 80000\n        },\n        \"model\": \"memcnn.models.resnet.ResNet\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.BasicBlock\",\n            \"layers\":[5, 5, 5],\n            \"channels_per_layer\":[16,16,32,64],\n            \"strides\":[1, 1, 2, 2],\n            \"init_max_pool\":false,\n            \"init_kernel_size\":3,\n            \"batch_norm_fix\":false\n        },\n        \"optimizer\": \"torch.optim.SGD\",\n        \"optimizer_params\": {\n            \"lr\":0.1,\n            \"momentum\":0.9,\n            \"weight_decay\":2e-4\n        },\n        \"trainer\":\"memcnn.trainers.classification.train\",\n        \"trainer_params\":{\n            \"loss\":\"memcnn.utils.loss.CrossEntropyLossTF\"\n        }\n    },\n\n    \"resnet110\":\n    {\n        \"base\": \"resnet32\",\n        \"model_params\": {\n            \"layers\":[18, 18, 18]\n        }\n    },\n\n    \"resnet164\":\n    {\n        \"base\": \"resnet110\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.Bottleneck\"\n        }\n    },\n\n    \"revnet38\":\n    {\n        \"base\": \"resnet32\",\n        \"model_params\": {\n            \"layers\":[3, 3, 3],\n            \"channels_per_layer\":[32,32,64,112],\n            \"block\":\"memcnn.models.resnet.RevBasicBlock\"\n        }\n    },\n\n    \"revnet110\":\n    {\n        \"base\": \"revnet38\",\n        \"model_params\": {\n            \"layers\":[9, 9, 9],\n            \"channels_per_layer\":[32,32,64,128]\n        }\n    },\n\n    \"revnet164\":\n    {\n        \"base\": \"revnet110\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.RevBottleneck\"\n        }\n    },\n\n    \"cifar10\":\n    {\n        \"data_loader\": \"memcnn.data.cifar.get_cifar_data_loaders\",\n        \"data_loader_params\": {\n            
\"dataset\": \"torchvision.datasets.CIFAR10\",\n            \"workers\": 16\n        },\n        \"model_params\": {\n            \"num_classes\":10\n        }\n    },\n\n    \"cifar100\":\n    {\n        \"data_loader\": \"memcnn.data.cifar.get_cifar_data_loaders\",\n        \"data_loader_params\": {\n            \"dataset\": \"torchvision.datasets.CIFAR100\",\n            \"workers\": 16\n        },\n        \"model_params\": {\n            \"num_classes\":100\n        }\n    }\n}"
  },
  {
    "path": "memcnn/config/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/config/tests/test_config.py",
    "content": "import unittest\nimport json\nimport os\nfrom memcnn.experiment.factory import load_experiment_config, experiment_config_parser\nfrom memcnn.config import Config\nimport memcnn.config\n\n\nclass ConfigTestCase(unittest.TestCase):\n\n    class ConfigTest(Config):\n        @staticmethod\n        def get_filename():\n            return os.path.join(Config.get_dir(), \"config.json.example\")\n\n    def setUp(self):\n        self.config = ConfigTestCase.ConfigTest()\n\n        self.config_fname = os.path.join(os.path.dirname(__file__), \"..\", \"config.json.example\")\n        self.experiments_fname = os.path.join(os.path.dirname(__file__), \"..\", \"experiments.json\")\n\n        def load_json_file(fname):\n            with open(fname, 'r') as f:\n                data = json.load(f)\n            return data\n\n        self.load_json_file = load_json_file\n\n    def test_loading_main_config(self):\n        self.assertTrue(os.path.exists(self.config.get_filename()))\n        data = self.config\n        self.assertTrue(isinstance(data, dict))\n        self.assertTrue(\"data_dir\" in data)\n        self.assertTrue(\"results_dir\" in data)\n\n    def test_loading_experiments_config(self):\n        self.assertTrue(os.path.exists(self.experiments_fname))\n        data = self.load_json_file(self.experiments_fname)\n        self.assertTrue(isinstance(data, dict))\n\n    def test_experiment_configs(self):\n        data = self.load_json_file(self.experiments_fname)\n        config = self.config\n        keys = data.keys()\n        for key in keys:\n            result = load_experiment_config(self.experiments_fname, [key])\n            self.assertTrue(isinstance(result, dict))\n            if \"dataset\" in result:\n                experiment_config_parser(result, config['data_dir'])\n\n    def test_config_get_filename(self):\n        self.assertEqual(Config.get_filename(), os.path.join(os.path.dirname(memcnn.config.__file__), \"config.json\"))\n\n    def 
test_config_get_dir(self):\n        self.assertEqual(Config.get_dir(), os.path.dirname(memcnn.config.__file__))\n\n    def test_verbose(self):\n        ConfigTestCase.ConfigTest(verbose=True)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "memcnn/data/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/data/cifar.py",
    "content": "import torch\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as transforms\nimport numpy as np\nfrom memcnn.data.sampling import NSamplesRandomSampler\nimport functools\n\n\ndef random_crop_transform(x, crop_size=3, img_size=(32, 32)):\n    cz = (crop_size + 1) // 2\n    x_pad = np.pad(x, ((cz, cz), (cz, cz), (0, 0)), mode='constant')\n    sx, sy = np.random.randint(crop_size + 1), np.random.randint(crop_size + 1)\n    return x_pad[sx:sx + img_size[0], sy:sy + img_size[1], :]\n\n\ndef tonumpy_fn(x):\n    return np.array(x.getdata()).reshape(x.size[1], x.size[0], 3)\n\n\ndef random_lr_flip_fn(x):\n    return np.copy(x[:, ::-1, :]) if np.random.random() >= 0.5 else x\n\n\ndef mean_subtract_fn(x, mean=0):\n    return x.astype(np.float32) - mean\n\n\ndef reformat_fn(x):\n    return x.transpose(2, 0, 1).astype(np.float32)\n\n\ndef get_cifar_data_loaders(dataset, data_dir, max_epoch, batch_size, workers):\n\n    train_set = dataset(root=data_dir, train=True, download=True)\n    valid_set = dataset(root=data_dir, train=False, download=True)\n\n    # calculate mean subtraction img with backwards compatibility for torchvision < 0.2.2\n    tdata = train_set.train_data if hasattr(train_set, 'train_data') else train_set.data\n    vdata = valid_set.test_data if hasattr(valid_set, 'test_data') else valid_set.data\n    mean_img = np.concatenate((tdata, vdata), axis=0).mean(axis=0)\n\n    mean_subtract_partial_fn = functools.partial(mean_subtract_fn, mean=mean_img)\n\n    # define transforms\n    randomcroplambda = transforms.Lambda(random_crop_transform)\n    tonumpy = transforms.Lambda(tonumpy_fn)\n    randomlrflip = transforms.Lambda(random_lr_flip_fn)\n    meansubtraction = transforms.Lambda(mean_subtract_partial_fn)\n    reformat = transforms.Lambda(reformat_fn)\n    totensor = transforms.Lambda(torch.from_numpy)\n    tfs = transforms.Compose([\n        tonumpy,\n        meansubtraction,\n        randomcroplambda,\n        
randomlrflip,\n        reformat,\n        totensor\n    ])\n\n    train_set.transform = tfs\n    valid_set.transform = tfs\n    sampler = NSamplesRandomSampler(train_set, max_epoch * batch_size)\n\n    train_loader = DataLoader(train_set,\n                              batch_size=batch_size, shuffle=False,\n                              sampler=sampler, num_workers=workers,\n                              pin_memory=True)\n\n    val_loader = DataLoader(valid_set,\n                            batch_size=batch_size, shuffle=False,\n                            num_workers=workers, pin_memory=True)\n\n    return train_loader, val_loader\n"
  },
  {
    "path": "memcnn/data/sampling.py",
    "content": "import torch\nfrom torch.utils.data.sampler import Sampler\n\n\nclass NSamplesRandomSampler(Sampler):\n    \"\"\"Samples elements randomly, with replacement,\n    always in blocks all elements of the dataset.\n    Only the remainder will be sampled with less elements.\n\n    Arguments:\n        data_source (Dataset): dataset to sample from\n        nsamples (int): number of total samples. Note: will always be cast to int\n    \"\"\"\n\n    @property\n    def nsamples(self):\n        return self._nsamples\n\n    @nsamples.setter\n    def nsamples(self, value):\n        self._nsamples = int(value)\n\n    def __init__(self, data_source, nsamples):\n        self.data_source = data_source\n        self.nsamples = nsamples\n\n    def __iter__(self):\n        samples = torch.LongTensor()\n        len_data_source = len(self.data_source)\n        for _ in range(self.nsamples // len_data_source):\n            samples = torch.cat((samples, torch.randperm(len_data_source).long()))\n        if self.nsamples % len_data_source > 0:\n            samples = torch.cat((samples, torch.randperm(self.nsamples % len_data_source).long()))\n        return iter(samples)\n\n    def __len__(self):\n        return self.nsamples\n"
  },
  {
    "path": "memcnn/data/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/data/tests/test_cifar.py",
    "content": "import pytest\nfrom memcnn.data.cifar import get_cifar_data_loaders, random_crop_transform\nimport torch.utils.data as data\nimport numpy as np\nfrom PIL import Image\n\n\n@pytest.mark.parametrize('crop_size,img_size', [(4, (32, 32)), (0, (32, 32))])\ndef test_random_crop_transform(crop_size, img_size):\n    np.random.seed(42)\n    img = np.random.random((img_size[0], img_size[1], 3))\n    imgres = random_crop_transform(img, crop_size, img_size)\n    assert imgres.shape == img.shape\n    assert imgres.dtype == img.dtype\n    if crop_size == 0:\n        assert np.array_equal(img, imgres)\n\n\n@pytest.mark.parametrize('max_epoch,batch_size', [(10, 2), (20, 4), (1, 1)])\ndef test_cifar_data_loaders(max_epoch, batch_size):\n    np.random.seed(42)\n\n    class TestDataset(data.Dataset):\n        def __init__(self, train=True, *args, **kwargs):\n            self.train = train\n            self.args = args\n            self.kwargs = kwargs\n            if self.train:\n                self.train_data = (np.random.random_sample((20, 32, 32, 3)) * 255).astype(np.uint8)\n            else:\n                self.test_data = (np.random.random_sample((10, 32, 32, 3)) * 255).astype(np.uint8)\n            self.transform = lambda val: val\n\n        def __getitem__(self, idx):\n            img = self.train_data[idx] if self.train else self.test_data[idx]\n            img = Image.fromarray(img)\n            img = self.transform(img)\n            return img, np.array(idx)\n\n        def __len__(self):\n            return len(self.train_data) if self.train else len(self.test_data)\n\n    max_epoch = 10\n    batch_size = 2\n    workers = 0\n    train_loader, val_loader = get_cifar_data_loaders(TestDataset, '', max_epoch, batch_size, workers=workers)\n\n    xsize = (batch_size, 3, 32, 32)\n    ysize = (batch_size, )\n    count = 0\n    for x, y in train_loader:\n        count += 1\n        assert x.shape == xsize\n        assert y.shape == ysize\n\n    assert count == 
max_epoch\n    assert count == len(train_loader)\n\n    count = 0\n    for x, y in val_loader:\n        count += 1\n        assert x.shape == xsize\n        assert y.shape == ysize\n\n    assert count == len(val_loader.dataset) // batch_size\n    assert count == len(val_loader)\n"
  },
  {
    "path": "memcnn/data/tests/test_sampling.py",
    "content": "import pytest\nfrom memcnn.data.sampling import NSamplesRandomSampler\nimport torch.utils.data as data\nimport numpy as np\n\n\n@pytest.mark.parametrize('nsamples,data_samples', [(1, 1), (14, 10), (10, 14), (5, 1), (1, 5), (0, 10),\n                                                   (np.array(4, dtype=np.int64), 12),\n                                                   (np.int64(4), 12),\n                                                   (np.array(12, dtype=np.int64), 3),\n                                                   (np.int64(12), 3)])\n@pytest.mark.parametrize('assign_after_creation', [False, True])\ndef test_random_sampler(nsamples, data_samples, assign_after_creation):\n\n    class TestDataset(data.Dataset):\n        def __init__(self, elements):\n            self.elements = elements\n\n        def __getitem__(self, idx):\n            return idx, idx\n\n        def __len__(self):\n            return self.elements\n\n    datasrc = TestDataset(data_samples)\n    sampler = NSamplesRandomSampler(datasrc, nsamples=nsamples if not assign_after_creation else -1)\n    if assign_after_creation:\n        sampler.nsamples = nsamples\n    count = 0\n    elements = []\n    for e in sampler:\n        elements.append(e)\n        count += 1\n        if count % data_samples == 0:\n            assert len(np.unique(elements)) == len(elements)\n            elements = []\n    assert count == nsamples\n    assert len(sampler) == nsamples\n    assert sampler.__len__() == nsamples\n"
  },
  {
    "path": "memcnn/examples/minimal.py",
    "content": "import torch\nimport torch.nn as nn\nimport memcnn\n\n\n# define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d\nclass ExampleOperation(nn.Module):\n    def __init__(self, channels):\n        super(ExampleOperation, self).__init__()\n        self.seq = nn.Sequential(\n                                    nn.Conv2d(in_channels=channels, out_channels=channels,\n                                              kernel_size=(3, 3), padding=1),\n                                    nn.BatchNorm2d(num_features=channels),\n                                    nn.ReLU(inplace=True)\n                                )\n\n    def forward(self, x):\n        return self.seq(x)\n\n\n# generate some random input data (batch_size, num_channels, y_elements, x_elements)\nX = torch.rand(2, 10, 8, 8)\n\n# application of the operation(s) the normal way\nmodel_normal = ExampleOperation(channels=10)\nmodel_normal.eval()\n\nY = model_normal(X)\n\n# turn the ExampleOperation invertible using an additive coupling\ninvertible_module = memcnn.AdditiveCoupling(\n    Fm=ExampleOperation(channels=10 // 2),\n    Gm=ExampleOperation(channels=10 // 2)\n)\n\n# test that it is actually a valid invertible module (has a valid inverse method)\nassert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)\n\n# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training\ninvertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True)\n\n# by default the module is set to training, the following sets this to evaluation\n# note that this is required to pass input tensors to the model with requires_grad=False (inference only)\ninvertible_module_wrapper.eval()\n\n# test that the wrapped module is also a valid invertible module\nassert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape)\n\n# compute the forward pass 
using the wrapper\nY2 = invertible_module_wrapper.forward(X)\n\n# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2\nX2 = invertible_module_wrapper.inverse(Y2)\n\n# test that the input and approximation are similar\nassert torch.allclose(X, X2, atol=1e-06)\n"
  },
  {
    "path": "memcnn/examples/test_examples.py",
    "content": "import torch\nimport sys\n\n\ndef test_minimal():\n    import minimal\n    # Input and inversed output should be approximately the same\n    assert torch.allclose(minimal.X, minimal.X2, atol=1e-06)\n\n    # Output of the wrapped invertible module is unlikely to match the normal output of F\n    assert not torch.allclose(minimal.Y2, minimal.Y)\n\n    # Cleanup minimal module and variables\n    del minimal.X\n    del minimal.Y\n    del minimal.Y2\n    del minimal.X2\n    del minimal\n    del sys.modules['minimal']\n"
  },
  {
    "path": "memcnn/experiment/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/experiment/factory.py",
    "content": "import json\nimport copy\n\n\ndef get_attr_from_module(pclass):\n    pclass = pclass.rsplit(\".\", 1)\n    mod = __import__(pclass[0], fromlist=[str(pclass[1])])\n    return getattr(mod, pclass[1])\n\n\ndef load_experiment_config(experiments_file, experiment_tags):\n    with open(experiments_file, 'r') as f:\n        data = json.load(f)\n    d = {}\n    for tag in experiment_tags:\n        _inject_items(build_dict(data, tag), d)\n\n    return d\n\n\ndef _inject_items(tempdict, d):\n    \"\"\"inject tempdict into d\"\"\"\n    for k, v in tempdict.items():\n        if isinstance(v, dict):\n            if k not in d:\n                d[k] = {}\n            d[k] = _inject_items(v, d[k])\n        else:\n            d[k] = v\n    return d\n\n\ndef build_dict(experiments_dict, experiment_name, classhist=None):\n    tempdict = experiments_dict[experiment_name]\n    if classhist is None:\n        classhist = []\n    classhist.append(experiment_name)\n    if not ('base' in tempdict) or (tempdict['base'] is None):\n        return copy.deepcopy(tempdict)\n    elif tempdict['base'] in classhist:\n        raise RuntimeError('Circular dependency found...')\n    else:\n        d = build_dict(experiments_dict, tempdict['base'], classhist)\n        return _inject_items(tempdict, d)\n\n\ndef experiment_config_parser(d, data_dir, workers=None):\n    trainer = get_attr_from_module(d['trainer'])\n\n    model = get_attr_from_module(d['model'])\n    model_params = copy.deepcopy(d['model_params'])\n    if 'block' in model_params:\n        model_params['block'] = get_attr_from_module(model_params['block'])\n    model = model(**model_params)\n\n    optimizer = get_attr_from_module(d['optimizer'])\n    optimizer = optimizer(model.parameters(), **d['optimizer_params'])\n\n    dl_params = copy.deepcopy(d['data_loader_params'])\n    dl_params['dataset'] = get_attr_from_module(dl_params['dataset'])\n    dl_params['data_dir'] = data_dir\n    dl_params['workers'] = 
dl_params['workers'] if workers is None else workers\n\n    train_loader, val_loader = get_attr_from_module(d['data_loader'])(**dl_params)\n\n    trainer_params = {}\n    if 'trainer_params' in d:\n        trainer_params = copy.deepcopy(d['trainer_params'])\n        if 'loss' in trainer_params:\n            trainer_params['loss'] = get_attr_from_module(trainer_params['loss'])()\n\n    trainer_params = dict(\n        train_loader=train_loader,\n        test_loader=val_loader,\n        **trainer_params\n    )\n\n    return model, optimizer, trainer, trainer_params\n"
  },
  {
    "path": "memcnn/experiment/manager.py",
    "content": "import os\nimport glob\nimport torch\nimport logging\nimport shutil\nimport numpy as np\n\n\nclass ExperimentManager(object):\n    def __init__(self, experiment_dir, model=None, optimizer=None):\n        self.logger = logging.getLogger(type(self).__name__)\n        self.experiment_dir = experiment_dir\n        self.model = model\n        self.optimizer = optimizer\n        self.model_dir = os.path.join(self.experiment_dir, \"state\", \"model\")\n        self.optim_dir = os.path.join(self.experiment_dir, \"state\", \"optimizer\")\n        self.log_dir = os.path.join(self.experiment_dir, \"log\")\n        self.dirs = (self.experiment_dir, self.model_dir, self.log_dir, self.optim_dir)\n\n    def make_dirs(self):\n        for d in self.dirs:\n            if not os.path.exists(d):\n                os.makedirs(d)\n        assert(self.all_dirs_exists())  # nosec\n\n    def delete_dirs(self):\n        for d in self.dirs:\n            if os.path.exists(d):\n                shutil.rmtree(d)\n        assert(not self.any_dir_exists())  # nosec\n\n    def any_dir_exists(self):\n        return any([os.path.exists(d) for d in self.dirs])\n\n    def all_dirs_exists(self):\n        return all([os.path.exists(d) for d in self.dirs])\n\n    def save_model_state(self, epoch):\n        model_fname = os.path.join(self.model_dir, \"{}.pt\".format(epoch))\n        self.logger.info(\"Saving model state to: {}\".format(model_fname))\n        torch.save(self.model.state_dict(), model_fname)\n\n    def load_model_state(self, epoch):\n        model_fname = os.path.join(self.model_dir, \"{}.pt\".format(epoch))\n        self.logger.info(\"Loading model state from: {}\".format(model_fname))\n        self.model.load_state_dict(torch.load(model_fname))\n\n    def save_optimizer_state(self, epoch):\n        optim_fname = os.path.join(self.optim_dir, \"{}.pt\".format(epoch))\n        self.logger.info(\"Saving optimizer state to: {}\".format(optim_fname))\n        
torch.save(self.optimizer.state_dict(), optim_fname)\n\n    def load_optimizer_state(self, epoch):\n        optim_fname = os.path.join(self.optim_dir, \"{}.pt\".format(epoch))\n        self.logger.info(\"Loading optimizer state from {}\".format(optim_fname))\n        self.optimizer.load_state_dict(torch.load(optim_fname))\n\n    def save_train_state(self, epoch):\n        self.save_model_state(epoch)\n        self.save_optimizer_state(epoch)\n\n    def load_train_state(self, epoch):\n        self.load_model_state(epoch)\n        self.load_optimizer_state(epoch)\n\n    def get_last_model_iteration(self):\n        return np.array([0] + [int(os.path.basename(e).split(\".\")[0])\n                               for e in glob.glob(os.path.join(self.model_dir, \"*.pt\"))]).max()\n\n    def load_last_train_state(self):\n        self.load_train_state(self.get_last_model_iteration())\n"
  },
  {
    "path": "memcnn/experiment/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/experiment/tests/test_factory.py",
    "content": "import pytest\nimport os\nimport memcnn.experiment.factory\nfrom memcnn.config import Config\n\n\ndef test_get_attr_from_module():\n    a = memcnn.experiment.factory.get_attr_from_module('memcnn.experiment.factory.get_attr_from_module')\n    assert a is memcnn.experiment.factory.get_attr_from_module\n\n\ndef test_load_experiment_config():\n    cfg_fname = os.path.join(Config.get_dir(), 'experiments.json')\n    memcnn.experiment.factory.load_experiment_config(cfg_fname, ['cifar10', 'resnet110'])\n\n\n@pytest.mark.skip(reason=\"Covered more efficiently by test_train.test_run_experiment\")\ndef test_experiment_config_parser(tmp_path):\n    tmp_data_dir = tmp_path / \"tmpdata\"\n    cfg_fname = os.path.join(Config.get_dir(), 'experiments.json')\n    cfg = memcnn.experiment.factory.load_experiment_config(cfg_fname, ['cifar10', 'resnet110'])\n    memcnn.experiment.factory.experiment_config_parser(cfg, str(tmp_data_dir), workers=None)\n\n\ndef test_circular_dependency(tmp_path):\n    p = str(tmp_path / \"circular.json\")\n    content = u'{ \"circ\": { \"base\": \"circ\" } }'\n    with open(p, 'w') as fh:\n        fh.write(content)\n    with open(p, 'r') as fh:\n        assert fh.read() == content\n    with pytest.raises(RuntimeError):\n        memcnn.experiment.factory.load_experiment_config(p, ['circ'])\n"
  },
  {
    "path": "memcnn/experiment/tests/test_manager.py",
    "content": "from memcnn.experiment.manager import ExperimentManager\nimport torch.nn\n\n\ndef test_experiment_manager(tmp_path):\n    exp_dir = tmp_path / \"test_exp_dir\"\n    man = ExperimentManager(str(exp_dir))\n    assert man.model is None\n    assert man.optimizer is None\n\n    man.make_dirs()\n    assert exp_dir.exists()\n    assert (exp_dir / \"log\").exists()\n    assert (exp_dir / \"state\" / \"model\").exists()\n    assert (exp_dir / \"state\" / \"optimizer\").exists()\n    assert man.all_dirs_exists()\n    assert man.any_dir_exists()\n\n    man.delete_dirs()\n    assert not exp_dir.exists()\n    assert not (exp_dir / \"log\").exists()\n    assert not (exp_dir / \"state\" / \"model\").exists()\n    assert not (exp_dir / \"state\" / \"optimizer\").exists()\n    assert not man.all_dirs_exists()\n    assert not man.any_dir_exists()\n\n    man.make_dirs()\n\n    man.model = torch.nn.Conv2d(2, 1, 3)\n    w = man.model.weight.clone()\n    man.save_model_state(0)\n    with torch.no_grad():\n        man.model.weight.zero_()\n    man.save_model_state(100)\n    assert not man.model.weight.equal(w)\n    assert man.get_last_model_iteration() == 100\n\n    man.load_model_state(0)\n    assert man.model.weight.equal(w)\n\n    optimizer = torch.optim.SGD(man.model.parameters(), lr=0.01, momentum=0.1)\n    man.optimizer = optimizer\n\n    man.save_train_state(100)\n\n    w = man.model.weight.clone()\n    sd = man.optimizer.state_dict().copy()\n\n    man.model.train()\n\n    x = torch.ones(5, 2, 5, 5)\n    x.requires_grad = True\n    y = torch.ones(5, 1, 3, 3)\n    y.requires_grad = False\n\n    ypred = man.model(x)\n    loss = torch.nn.MSELoss()(ypred, y)\n    man.optimizer.zero_grad()\n    loss.backward()\n    man.optimizer.step()\n\n    man.save_train_state(101)\n    assert not man.model.weight.equal(w)\n    assert sd != man.optimizer.state_dict()\n    w2 = man.model.weight.clone()\n    sd2 = man.optimizer.state_dict().copy()\n\n    man.load_train_state(100)\n    
assert man.model.weight.equal(w)\n    assert sd == man.optimizer.state_dict()\n\n    man.load_last_train_state() # should be 101\n    assert not man.model.weight.equal(w)\n    assert sd != man.optimizer.state_dict()\n    assert man.model.weight.equal(w2)\n\n    def retrieve_mom_buffer(sd):\n        keys = [e for e in sd['state'].keys()]\n        if len(keys) == 0:\n            return torch.zeros(0)\n        else:\n            return sd['state'][keys[0]]['momentum_buffer']\n\n    assert torch.equal(retrieve_mom_buffer(sd2), retrieve_mom_buffer(man.optimizer.state_dict()))\n"
  },
  {
    "path": "memcnn/models/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/models/additive.py",
    "content": "import warnings\nimport torch\nimport torch.nn as nn\nimport copy\nfrom torch import set_grad_enabled\n\n\nclass AdditiveCoupling(nn.Module):\n    def __init__(self, Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1):\n        \"\"\"\n        This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and :math:`Gm` according to:\n\n        :math:`(x1, x2) = x`\n\n        :math:`y1 = x1 + Fm(x2)`\n\n        :math:`y2 = x2 + Gm(y1)`\n\n        :math:`y = (y1, y2)`\n\n        Parameters\n        ----------\n            Fm : :obj:`torch.nn.Module`\n                A torch.nn.Module encapsulating an arbitrary function\n\n            Gm : :obj:`torch.nn.Module`\n                A torch.nn.Module encapsulating an arbitrary function\n                (If not specified a deepcopy of Fm is used as a Module)\n\n            implementation_fwd : :obj:`int`\n                Switch between different Additive Operation implementations for forward pass. Default = -1\n\n            implementation_bwd : :obj:`int`\n                Switch between different Additive Operation implementations for inverse pass. Default = -1\n\n            split_dim : :obj:`int`\n                Dimension to split the input tensors on. 
Default = 1, generally corresponding to channels.\n\n        \"\"\"\n        super(AdditiveCoupling, self).__init__()\n        # mirror the passed module, without parameter sharing...\n        if Gm is None:\n            Gm = copy.deepcopy(Fm)\n        self.Gm = Gm\n        self.Fm = Fm\n        self.implementation_fwd = implementation_fwd\n        self.implementation_bwd = implementation_bwd\n        self.split_dim = split_dim\n        if implementation_bwd != -1 or implementation_fwd != -1:\n            warnings.warn(\"Other implementations than the default (-1) are now deprecated.\",\n                          DeprecationWarning)\n\n    def forward(self, x):\n        args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]\n\n        if self.implementation_fwd == 0:\n            out = AdditiveBlockFunction.apply(*args)\n        elif self.implementation_fwd == 1:\n            out = AdditiveBlockFunction2.apply(*args)\n        elif self.implementation_fwd == -1:\n            x1, x2 = torch.chunk(x, 2, dim=self.split_dim)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n            fmd = self.Fm.forward(x2)\n            y1 = x1 + fmd\n            gmd = self.Gm.forward(y1)\n            y2 = x2 + gmd\n            out = torch.cat([y1, y2], dim=self.split_dim)\n        else:\n            raise NotImplementedError(\"Selected implementation ({}) not implemented...\"\n                                      .format(self.implementation_fwd))\n        return out\n\n    def inverse(self, y):\n        args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]\n\n        if self.implementation_bwd == 0:\n            x = AdditiveBlockInverseFunction.apply(*args)\n        elif self.implementation_bwd == 1:\n            x = AdditiveBlockInverseFunction2.apply(*args)\n        elif self.implementation_bwd == -1:\n            y1, y2 = torch.chunk(y, 2, dim=self.split_dim)\n            y1, y2 = 
y1.contiguous(), y2.contiguous()\n            gmd = self.Gm.forward(y1)\n            x2 = y2 - gmd\n            fmd = self.Fm.forward(x2)\n            x1 = y1 - fmd\n            x = torch.cat([x1, x2], dim=self.split_dim)\n        else:\n            raise NotImplementedError(\"Inverse for selected implementation ({}) not implemented...\"\n                                      .format(self.implementation_bwd))\n        return x\n\n\nclass AdditiveBlock(AdditiveCoupling):\n    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):\n        warnings.warn(\"This class has been deprecated. Use the AdditiveCoupling class instead.\",\n                      DeprecationWarning)\n        super(AdditiveBlock, self).__init__(Fm=Fm, Gm=Gm,\n                                            implementation_fwd=implementation_fwd,\n                                            implementation_bwd=implementation_bwd)\n\n\nclass AdditiveBlockFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xin, Fm, Gm, *weights):\n        \"\"\"Forward pass computes:\n        {x1, x2} = x\n        y1 = x1 + Fm(x2)\n        y2 = x2 + Gm(y1)\n        output = {y1, y2}\n\n        Parameters\n        ----------\n        ctx : torch.autograd.Function\n            The backward pass context object\n        x : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this function\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert(xin.shape[1] % 2 == 0)  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        ctx.Fm = Fm\n        ctx.Gm = Gm\n\n        with torch.no_grad():\n            x = xin.detach()\n            # partition in two equally sized set of channels\n            x1, x2 = torch.chunk(x, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # compute outputs\n            fmr = Fm.forward(x2)\n\n            y1 = x1 + fmr\n            x1.set_()\n            del x1\n            gmr = Gm.forward(y1)\n            y2 = x2 + gmr\n            x2.set_()\n            del x2\n            output = torch.cat([y1, y2], dim=1)\n\n        ctx.save_for_backward(xin, output)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # pragma: no cover\n        # retrieve weight references\n        Fm, Gm = ctx.Fm, ctx.Gm\n\n        # retrieve input and output references\n        xin, output = ctx.saved_tensors\n        x = xin.detach()\n        x1, x2 = torch.chunk(x, 2, dim=1)\n        GWeights = [p for p in Gm.parameters()]\n        # partition output gradient also on channels\n        assert grad_output.shape[1] % 2 == 0  # nosec\n\n        with set_grad_enabled(True):\n            # compute outputs building a sub-graph\n            x1.requires_grad_()\n            x2.requires_grad_()\n\n            y1 = x1 + Fm.forward(x2)\n            y2 = x2 + Gm.forward(y1)\n            y = torch.cat([y1, y2], dim=1)\n\n            # perform full backward pass on graph...\n            dd = torch.autograd.grad(y, (x1, x2 ) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)\n\n            GWgrads = dd[2:2+len(GWeights)]\n            FWgrads = 
dd[2+len(GWeights):]\n            grad_input = torch.cat([dd[0], dd[1]], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\n\nclass AdditiveBlockInverseFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(cty, y, Fm, Gm, *weights):\n        \"\"\"Forward pass computes:\n        {y1, y2} = y\n        x2 = y2 - Gm(y1)\n        x1 = y1 - Fm(x2)\n        output = {x1, x2}\n\n        Parameters\n        ----------\n        cty : torch.autograd.Function\n            The backward pass context object\n        y : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert(y.shape[1] % 2 == 0)  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        cty.Fm = Fm\n        cty.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            y1, y2 = torch.chunk(y, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # compute outputs\n            gmr = Gm.forward(y1)\n\n            x2 = y2 - gmr\n            y2.set_()\n            del y2\n            fmr = Fm.forward(x2)\n            x1 = y1 - fmr\n            y1.set_()\n            del y1\n            output = torch.cat([x1, x2], dim=1)\n            x1.set_()\n            x2.set_()\n            del x1, x2\n\n        # save the (empty) input and (non-empty) output variables\n        cty.save_for_backward(y.data, output)\n\n        return output\n\n    @staticmethod\n    def backward(cty, grad_output):  # pragma: no cover\n        # retrieve weight references\n        Fm, Gm = cty.Fm, cty.Gm\n\n        # retrieve input and output references\n        yin, output = cty.saved_tensors\n        y = yin.detach()\n        y1, y2 = torch.chunk(y, 2, dim=1)\n        FWeights = [p for p in Fm.parameters()]\n\n        # partition output gradient also on channels\n        assert grad_output.shape[1] % 2 == 0  # nosec\n\n        with set_grad_enabled(True):\n            # compute outputs building a sub-graph\n            y2.requires_grad = True\n            y1.requires_grad = True\n\n            x2 = y2 - Gm.forward(y1)\n            x1 = y1 - Fm.forward(x2)\n            x = torch.cat([x1, x2], dim=1)\n\n            # perform full backward pass on graph...\n            dd = torch.autograd.grad(x, (y2, y1 ) + tuple(Fm.parameters()) 
+ tuple(Gm.parameters()), grad_output)\n\n            FWgrads = dd[2:2+len(FWeights)]\n            GWgrads = dd[2+len(FWeights):]\n            grad_input = torch.cat([dd[0], dd[1]], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\nclass AdditiveBlockFunction2(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xin, Fm, Gm, *weights):\n        \"\"\"Forward pass computes:\n        {x1, x2} = x\n        y1 = x1 + Fm(x2)\n        y2 = x2 + Gm(y1)\n        output = {y1, y2}\n\n        Parameters\n        ----------\n        ctx : torch.autograd.Function\n            The backward pass context object\n        x : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert xin.shape[1] % 2 == 0  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        ctx.Fm = Fm\n        ctx.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            x = xin.detach()\n            x1, x2 = torch.chunk(x, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # compute outputs\n            fmr = Fm.forward(x2)\n\n            y1 = x1 + fmr\n            x1.set_()\n            del x1\n            gmr = Gm.forward(y1)\n            y2 = x2 + gmr\n            x2.set_()\n            del x2\n            output = torch.cat([y1, y2], dim=1).detach_()\n\n        # save the input and output variables\n        ctx.save_for_backward(x, output)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # pragma: no cover\n\n        Fm, Gm = ctx.Fm, ctx.Gm\n        # are all variable objects now\n        x, output = ctx.saved_tensors\n\n        with torch.no_grad():\n            y1, y2 = torch.chunk(output, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # partition output gradient also on channels\n            assert(grad_output.shape[1] % 2 == 0)  # nosec\n            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)\n            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()\n\n        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:\n        # z1_stop, x2_stop, GW, FW\n        # Also recompute inputs (x1, x2) from outputs (y1, y2)\n        with set_grad_enabled(True):\n            z1_stop = y1.detach()\n            z1_stop.requires_grad = True\n\n            G_z1 = 
Gm.forward(z1_stop)\n            x2 = y2 - G_z1\n            x2_stop = x2.detach()\n            x2_stop.requires_grad = True\n\n            F_x2 = Fm.forward(x2_stop)\n            x1 = y1 - F_x2\n            x1_stop = x1.detach()\n            x1_stop.requires_grad = True\n\n            # compute outputs building a sub-graph\n            y1 = x1_stop + F_x2\n            y2 = x2_stop + G_z1\n\n            # calculate the final gradients for the weights and inputs\n            dd = torch.autograd.grad(y2, (z1_stop,) + tuple(Gm.parameters()), y2_grad, retain_graph=False)\n            z1_grad = dd[0] + y1_grad\n            GWgrads = dd[1:]\n\n            dd = torch.autograd.grad(y1, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)\n\n            FWgrads = dd[2:]\n            x2_grad = dd[1] + y2_grad\n            x1_grad = dd[0]\n            grad_input = torch.cat([x1_grad, x2_grad], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\n\nclass AdditiveBlockInverseFunction2(torch.autograd.Function):\n    @staticmethod\n    def forward(cty, y, Fm, Gm, *weights):\n        \"\"\"Forward pass computes:\n        {y1, y2} = y\n        x2 = y2 - Gm(y1)\n        x1 = y1 - Fm(x2)\n        output = {x1, x2}\n\n        Parameters\n        ----------\n        cty : torch.autograd.Function\n            The backward pass context object\n        y : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert(y.shape[1] % 2 == 0)  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        cty.Fm = Fm\n        cty.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            y1, y2 = torch.chunk(y, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # compute outputs\n            gmr = Gm.forward(y1)\n\n            x2 = y2 - gmr\n            y2.set_()\n            del y2\n            fmr = Fm.forward(x2)\n            x1 = y1 - fmr\n            y1.set_()\n            del y1\n            output = torch.cat([x1, x2], dim=1).detach_()\n\n        # save the input and output variables\n        cty.save_for_backward(y, output)\n\n        return output\n\n    @staticmethod\n    def backward(cty, grad_output):  # pragma: no cover\n\n        Fm, Gm = cty.Fm, cty.Gm\n        # are all variable objects now\n        y, output = cty.saved_tensors\n\n        with torch.no_grad():\n            x1, x2 = torch.chunk(output, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # partition output gradient also on channels\n            assert(grad_output.shape[1] % 2 == 0)  # nosec\n            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)\n            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()\n\n        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:\n        # z1_stop, y1_stop, GW, FW\n        # Also recompute inputs (y1, y2) from outputs (x1, x2)\n        with set_grad_enabled(True):\n            z1_stop = x2.detach()\n            z1_stop.requires_grad = True\n\n            F_z1 = Fm.forward(z1_stop)\n            y1 = x1 
+ F_z1\n            y1_stop = y1.detach()\n            y1_stop.requires_grad = True\n\n            G_y1 = Gm.forward(y1_stop)\n            y2 = x2 + G_y1\n            y2_stop = y2.detach()\n            y2_stop.requires_grad = True\n\n            # compute outputs building a sub-graph\n            z1 = y2_stop - G_y1\n            x1 = y1_stop - F_z1\n            x2 = z1\n\n            # calculate the final gradients for the weights and inputs\n            dd = torch.autograd.grad(x1, (z1_stop,) + tuple(Fm.parameters()), x1_grad)\n            z1_grad = dd[0] + x2_grad\n            FWgrads = dd[1:]\n\n            dd = torch.autograd.grad(x2, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)\n\n            GWgrads = dd[2:]\n            y1_grad = dd[1] + x1_grad\n            y2_grad = dd[0]\n\n            grad_input = torch.cat([y1_grad, y2_grad], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n"
  },
  {
    "path": "memcnn/models/affine.py",
    "content": "import torch\nimport torch.nn as nn\nimport copy\nimport warnings\nfrom torch import set_grad_enabled\n\nwarnings.filterwarnings(action='ignore', category=UserWarning)\n\n\nclass AffineAdapterNaive(nn.Module):\n    \"\"\" Naive Affine adapter\n\n        Outputs exp(f(x)), f(x) given f(.) and x\n    \"\"\"\n    def __init__(self, module):\n        super(AffineAdapterNaive, self).__init__()\n        self.f = module\n\n    def forward(self, x):\n        t = self.f(x)\n        s = torch.exp(t)\n        return s, t\n\n\nclass AffineAdapterSigmoid(nn.Module):\n    \"\"\" Sigmoid based affine adapter\n\n        Partitions the output h of f(x) = h into s and t by extracting every odd and even channel\n        Outputs sigmoid(s), t\n    \"\"\"\n    def __init__(self, module):\n        super(AffineAdapterSigmoid, self).__init__()\n        self.f = module\n\n    def forward(self, x):\n        h = self.f(x)\n        assert h.shape[1] % 2 == 0  # nosec\n        scale = torch.sigmoid(h[:, 1::2] + 2.0)\n        shift = h[:, 0::2]\n        return scale, shift\n\n\nclass AffineCoupling(nn.Module):\n    def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1):\n        \"\"\"\n        This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and :math:`Gm` according to:\n\n        :math:`(x1, x2) = x`\n\n        :math:`(log({s1}), t1) = Fm(x2)`\n\n        :math:`s1 = exp(log({s1}))`\n\n        :math:`y1 = s1 * x1 + t1`\n\n        :math:`(log({s2}), t2) = Gm(y1)`\n\n        :math:`s2 = exp(log({s2}))`\n\n        :math:`y2 = s2 * x2 + t2`\n\n        :math:`y = (y1, y2)`\n\n        Parameters\n        ----------\n            Fm : :obj:`torch.nn.Module`\n                A torch.nn.Module encapsulating an arbitrary function\n\n            Gm : :obj:`torch.nn.Module`\n                A torch.nn.Module encapsulating an arbitrary function\n                (If not specified a deepcopy 
of Gm is used as a Module)\n\n            adapter : :obj:`torch.nn.Module` class\n                An optional wrapper class A for Fm and Gm which must output\n                s, t = A(x) with shape(s) = shape(t) = shape(x)\n                s, t are respectively the scale and shift tensors for the affine coupling.\n\n            implementation_fwd : :obj:`int`\n                Switch between different Affine Operation implementations for forward pass. Default = -1\n\n            implementation_bwd : :obj:`int`\n                Switch between different Affine Operation implementations for inverse pass. Default = -1\n\n            split_dim : :obj:`int`\n                Dimension to split the input tensors on. Default = 1, generally corresponding to channels.\n        \"\"\"\n        super(AffineCoupling, self).__init__()\n        # mirror the passed module, without parameter sharing...\n        if Gm is None:\n            Gm = copy.deepcopy(Fm)\n        # apply the adapter class if it is given\n        self.Gm = adapter(Gm) if adapter is not None else Gm\n        self.Fm = adapter(Fm) if adapter is not None else Fm\n        self.implementation_fwd = implementation_fwd\n        self.implementation_bwd = implementation_bwd\n        self.split_dim = split_dim\n        if implementation_bwd != -1 or implementation_fwd != -1:\n            warnings.warn(\"Other implementations than the default (-1) are now deprecated.\",\n                          DeprecationWarning)\n\n    def forward(self, x):\n        args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]\n\n        if self.implementation_fwd == 0:\n            out = AffineBlockFunction.apply(*args)\n        elif self.implementation_fwd == 1:\n            out = AffineBlockFunction2.apply(*args)\n        elif self.implementation_fwd == -1:\n            x1, x2 = torch.chunk(x, 2, dim=self.split_dim)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n            fmr1, fmr2 
= self.Fm.forward(x2)\n            y1 = (x1 * fmr1) + fmr2\n            gmr1, gmr2 = self.Gm.forward(y1)\n            y2 = (x2 * gmr1) + gmr2\n            out = torch.cat([y1, y2], dim=self.split_dim)\n        else:\n            raise NotImplementedError(\"Selected implementation ({}) not implemented...\"\n                                      .format(self.implementation_fwd))\n        return out\n\n    def inverse(self, y):\n        args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]\n\n        if self.implementation_bwd == 0:\n            x = AffineBlockInverseFunction.apply(*args)\n        elif self.implementation_bwd == 1:\n            x = AffineBlockInverseFunction2.apply(*args)\n        elif self.implementation_bwd == -1:\n            y1, y2 = torch.chunk(y, 2, dim=self.split_dim)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n            gmr1, gmr2 = self.Gm.forward(y1)\n            x2 = (y2 - gmr2) / gmr1\n            fmr1, fmr2 = self.Fm.forward(x2)\n            x1 = (y1 - fmr2) / fmr1\n            x = torch.cat([x1, x2], dim=self.split_dim)\n        else:\n            raise NotImplementedError(\"Inverse for selected implementation ({}) not implemented...\"\n                                      .format(self.implementation_bwd))\n        return x\n\n\nclass AffineBlock(AffineCoupling):\n    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):\n        warnings.warn(\"This class has been deprecated. 
Use the AffineCoupling class instead.\",\n                      DeprecationWarning)\n        super(AffineBlock, self).__init__(Fm=Fm, Gm=Gm,\n                                          implementation_fwd=implementation_fwd,\n                                          implementation_bwd=implementation_bwd)\n\n\nclass AffineBlockFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xin, Fm, Gm, *weights):\n        \"\"\"Forward pass for the affine block computes:\n        {x1, x2} = x\n        {log_s1, t1} = Fm(x2)\n        s1 = exp(log_s1)\n        y1 = s1 * x1 + t1\n        {log_s2, t2} = Gm(y1)\n        s2 = exp(log_s2)\n        y2 = s2 * x2 + t2\n        output = {y1, y2}\n\n        Parameters\n        ----------\n        ctx : torch.autograd.function.RevNetFunctionBackward\n            The backward pass context object\n        x : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this function\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert xin.shape[1] % 2 == 0  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        ctx.Fm = Fm\n        ctx.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            x = xin.detach()\n            x1, x2 = torch.chunk(x, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # compute outputs\n            x2var = x2\n            fmr1, fmr2 = Fm.forward(x2var)\n\n            y1 = (x1 * fmr1) + fmr2\n            x1.set_()\n            del x1\n            y1var = y1\n            gmr1, gmr2 = Gm.forward(y1var)\n            y2 = (x2 * gmr1) + gmr2\n            x2.set_()\n            del x2\n            output = torch.cat([y1, y2], dim=1).detach_()\n\n        # save the (empty) input and (non-empty) output variables\n        ctx.save_for_backward(xin, output)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # pragma: no cover\n        # retrieve weight references\n        Fm, Gm = ctx.Fm, ctx.Gm\n\n        # retrieve input and output references\n        xin, output = ctx.saved_tensors\n        x = xin.detach()\n        x1, x2 = torch.chunk(x.detach(), 2, dim=1)\n        GWeights = [p for p in Gm.parameters()]\n\n        # partition output gradient also on channels\n        assert (grad_output.shape[1] % 2 == 0)  # nosec\n\n        with set_grad_enabled(True):\n            # compute outputs building a sub-graph\n            x1.requires_grad = True\n            x2.requires_grad = True\n\n            fmr1, fmr2 = Fm.forward(x2)\n            y1 = x1 * fmr1 + fmr2\n            gmr1, gmr2 = Gm.forward(y1)\n            y2 = x2 * gmr1 + gmr2\n            y = torch.cat([y1, 
y2], dim=1)\n\n            # perform full backward pass on graph...\n            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)\n\n            GWgrads = dd[2:2 + len(GWeights)]\n            FWgrads = dd[2 + len(GWeights):]\n            grad_input = torch.cat([dd[0], dd[1]], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\n\nclass AffineBlockInverseFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(cty, yin, Fm, Gm, *weights):\n        \"\"\"Forward inverse pass for the affine block computes:\n        {y1, y2} = y\n        {log_s2, t2} = Gm(y1)\n        s2 = exp(log_s2)\n        x2 = (y2 - t2) / s2\n        {log_s1, t1} = Fm(x2)\n        s1 = exp(log_s1)\n        x1 = (y1 - t1) / s1\n        output = {x1, x2}\n\n        Parameters\n        ----------\n        cty : torch.autograd.function.RevNetInverseFunctionBackward\n            The backward pass context object\n        y : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert yin.shape[1] % 2 == 0  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        cty.Fm = Fm\n        cty.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            y = yin.detach()\n            y1, y2 = torch.chunk(y, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # compute outputs\n            y1var = y1\n\n            gmr1, gmr2 = Gm.forward(y1var)\n\n            x2 = (y2 - gmr2) / gmr1\n            y2.set_()\n            del y2\n            x2var = x2\n            fmr1, fmr2 = Fm.forward(x2var)\n\n            x1 = (y1 - fmr2) / fmr1\n            y1.set_()\n            del y1\n            output = torch.cat([x1, x2], dim=1).detach_()\n\n        # save input and output variables\n        cty.save_for_backward(yin, output)\n\n        return output\n\n    @staticmethod\n    def backward(cty, grad_output):  # pragma: no cover\n        # retrieve weight references\n        Fm, Gm = cty.Fm, cty.Gm\n\n        # retrieve input and output references\n        yin, output = cty.saved_tensors\n        y = yin.detach()\n        y1, y2 = torch.chunk(y.detach(), 2, dim=1)\n        FWeights = [p for p in Gm.parameters()]\n\n        # partition output gradient also on channels\n        assert grad_output.shape[1] % 2 == 0  # nosec\n\n        with set_grad_enabled(True):\n            # compute outputs building a sub-graph\n            y2.requires_grad = True\n            y1.requires_grad = True\n\n            gmr1, gmr2 = Gm.forward(y1)  #\n            x2 = (y2 - gmr2) / gmr1\n            fmr1, fmr2 = Fm.forward(x2)\n            x1 = (y1 - fmr2) / fmr1\n            x = torch.cat([x1, x2], dim=1)\n\n  
          # perform full backward pass on graph...\n            dd = torch.autograd.grad(x, (y2, y1) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)\n\n            FWgrads = dd[2:2 + len(FWeights)]\n            GWgrads = dd[2 + len(FWeights):]\n            grad_input = torch.cat([dd[0], dd[1]], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\n\nclass AffineBlockFunction2(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xin, Fm, Gm, *weights):\n        \"\"\"Forward pass for the affine block computes:\n        {x1, x2} = x\n        {log_s1, t1} = Fm(x2)\n        s1 = exp(log_s1)\n        y1 = s1 * x1 + t1\n        {log_s2, t2} = Gm(y1)\n        s2 = exp(log_s2)\n        y2 = s2 * x2 + t2\n        output = {y1, y2}\n\n        Parameters\n        ----------\n        ctx : torch.autograd.function.RevNetFunctionBackward\n            The backward pass context object\n        x : TorchTensor\n            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert xin.shape[1] % 2 == 0  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        ctx.Fm = Fm\n        ctx.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            x = xin.detach()\n            x1, x2 = torch.chunk(x, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # compute outputs\n            x2var = x2\n            fmr1, fmr2 = Fm.forward(x2var)\n\n            y1 = x1 * fmr1 + fmr2\n            x1.set_()\n            del x1\n            y1var = y1\n            gmr1, gmr2 = Gm.forward(y1var)\n            y2 = x2 * gmr1 + gmr2\n            x2.set_()\n            del x2\n            output = torch.cat([y1, y2], dim=1).detach_()\n\n        # save the input and output variables\n        ctx.save_for_backward(xin, output)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):  # pragma: no cover\n        Fm, Gm = ctx.Fm, ctx.Gm\n        # are all variable objects now\n        x, output = ctx.saved_tensors\n\n        with set_grad_enabled(False):\n            y1, y2 = torch.chunk(output, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # partition output gradient also on channels\n            assert (grad_output.shape[1] % 2 == 0)  # nosec\n            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)\n            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()\n\n        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:\n        # z1_stop, x2_stop, GW, FW\n        # Also recompute inputs (x1, x2) from outputs (y1, y2)\n        with set_grad_enabled(True):\n            
z1_stop = y1\n            z1_stop.requires_grad = True\n\n            G_z11, G_z12 = Gm.forward(z1_stop)\n            x2 = (y2 - G_z12) / G_z11\n            x2_stop = x2.detach()\n            x2_stop.requires_grad = True\n\n            F_x21, F_x22 = Fm.forward(x2_stop)\n            x1 = (y1 - F_x22) / F_x21\n            x1_stop = x1.detach()\n            x1_stop.requires_grad = True\n\n            # compute outputs building a sub-graph\n            z1 = x1_stop * F_x21 + F_x22\n            y2_ = x2_stop * G_z11 + G_z12\n            y1_ = z1\n\n            # calculate the final gradients for the weights and inputs\n            dd = torch.autograd.grad(y2_, (z1_stop,) + tuple(Gm.parameters()), y2_grad)\n            z1_grad = dd[0] + y1_grad\n            GWgrads = dd[1:]\n\n            dd = torch.autograd.grad(y1_, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)\n\n            FWgrads = dd[2:]\n            x2_grad = dd[1] + y2_grad\n            x1_grad = dd[0]\n            grad_input = torch.cat([x1_grad, x2_grad], dim=1)\n\n            y1_.detach_()\n            y2_.detach_()\n            del y1_, y2_\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n\n\nclass AffineBlockInverseFunction2(torch.autograd.Function):\n    @staticmethod\n    def forward(cty, yin, Fm, Gm, *weights):\n        \"\"\"Forward pass for the affine block computes:\n\n        Parameters\n        ----------\n        cty : torch.autograd.function.RevNetInverseFunctionBackward\n            The backward pass context object\n        y : TorchTensor\n            Input tensor. 
Must have channels (2nd dimension) that can be partitioned in two equal partitions\n        Fm : nn.Module\n            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape\n        Gm : nn.Module\n            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape\n        *weights : TorchTensor\n            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}\n\n        Note\n        ----\n        All tensor/autograd variable input arguments and the output are\n        TorchTensors for the scope of this fuction\n\n        \"\"\"\n        # check if possible to partition into two equally sized partitions\n        assert yin.shape[1] % 2 == 0  # nosec\n\n        # store partition size, Fm and Gm functions in context\n        cty.Fm = Fm\n        cty.Gm = Gm\n\n        with torch.no_grad():\n            # partition in two equally sized set of channels\n            y = yin.detach()\n            y1, y2 = torch.chunk(y, 2, dim=1)\n            y1, y2 = y1.contiguous(), y2.contiguous()\n\n            # compute outputs\n            y1var = y1\n            gmr1, gmr2 = Gm.forward(y1var)\n\n            x2 = (y2 - gmr2) / gmr1\n            y2.set_()\n            del y2\n            x2var = x2\n            fmr1, fmr2 = Fm.forward(x2var)\n            x1 = (y1 - fmr2) / fmr1\n            y1.set_()\n            del y1\n            output = torch.cat([x1, x2], dim=1).detach_()\n\n        # save the input and output variables\n        cty.save_for_backward(yin, output)\n\n        return output\n\n    @staticmethod\n    def backward(cty, grad_output):  # pragma: no cover\n        Fm, Gm = cty.Fm, cty.Gm\n        # are all variable objects now\n        y, output = cty.saved_tensors\n\n        with set_grad_enabled(False):\n            x1, x2 = torch.chunk(output, 2, dim=1)\n            x1, x2 = x1.contiguous(), x2.contiguous()\n\n            # partition output gradient also on channels\n   
         assert (grad_output.shape[1] % 2 == 0)  # nosec\n            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)\n            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()\n\n        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:\n        # z1_stop, y1_stop, GW, FW\n        # Also recompute inputs (y1, y2) from outputs (x1, x2)\n        with set_grad_enabled(True):\n            z1_stop = x2\n            z1_stop.requires_grad = True\n\n            F_z11, F_z12 = Fm.forward(z1_stop)\n            y1 = x1 * F_z11 + F_z12\n            y1_stop = y1.detach()\n            y1_stop.requires_grad = True\n\n            G_y11, G_y12 = Gm.forward(y1_stop)\n            y2 = x2 * G_y11 + G_y12\n            y2_stop = y2.detach()\n            y2_stop.requires_grad = True\n\n            # compute outputs building a sub-graph\n            z1 = (y2_stop - G_y12) / G_y11\n            x1_ = (y1_stop - F_z12) / F_z11\n            x2_ = z1\n\n            # calculate the final gradients for the weights and inputs\n            dd = torch.autograd.grad(x1_, (z1_stop,) + tuple(Fm.parameters()), x1_grad)\n            z1_grad = dd[0] + x2_grad\n            FWgrads = dd[1:]\n\n            dd = torch.autograd.grad(x2_, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)\n\n            GWgrads = dd[2:]\n            y1_grad = dd[1] + x1_grad\n            y2_grad = dd[0]\n\n            grad_input = torch.cat([y1_grad, y2_grad], dim=1)\n\n        return (grad_input, None, None) + FWgrads + GWgrads\n"
  },
  {
    "path": "memcnn/models/resnet.py",
    "content": "\"\"\"ResNet/RevNet implementation used for The Reversible Residual Network\nImplemented in PyTorch instead of TensorFlow.\n\n@inproceedings{gomez17revnet,\n  author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},\n  title     = {The Reversible Residual Network: Backpropagation without Storing Activations}\n  booktitle = {NIPS},\n  year      = {2017},\n}\n\nGithub: https://github.com/renmengye/revnet-public\n\nAuthor: Sil van de Leemput\n\n\"\"\"\nimport torch.nn as nn\nimport math\nfrom memcnn.models.revop import InvertibleModuleWrapper, create_coupling\n\n__all__ = ['ResNet', 'BasicBlock', 'Bottleneck', 'RevBasicBlock', 'RevBottleneck', 'BasicBlockSub', 'BottleneckSub',\n           'conv3x3', 'batch_norm']\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\n\ndef batch_norm(x):\n    \"\"\"match Tensorflow batch norm settings\"\"\"\n    return nn.BatchNorm2d(x, momentum=0.99, eps=0.001)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):\n        super(BasicBlock, self).__init__()\n        self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n        out = self.basicblock_sub(x)\n        if self.downsample is not None:\n            residual = self.downsample(x)\n        out += residual\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):\n        super(Bottleneck, self).__init__()\n        self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation)\n        self.downsample = downsample\n        self.stride = stride\n\n   
 def forward(self, x):\n        residual = x\n        out = self.bottleneck_sub(x)\n        if self.downsample is not None:\n            residual = self.downsample(x)\n        out += residual\n        return out\n\n\nclass RevBasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):\n        super(RevBasicBlock, self).__init__()\n        if downsample is None and stride == 1:\n            gm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)\n            fm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)\n            coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive')\n            self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False)\n        else:\n            self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        if self.downsample is not None:\n            out = self.basicblock_sub(x)\n            residual = self.downsample(x)\n            out += residual\n        else:\n            out = self.revblock(x)\n        return out\n\n\nclass RevBottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):\n        super(RevBottleneck, self).__init__()\n        if downsample is None and stride == 1:\n            gm = BottleneckSub(inplanes // 2, planes // 2, stride, noactivation)\n            fm = BottleneckSub(inplanes // 2, planes // 2, stride, noactivation)\n            coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive')\n            self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False)\n        else:\n            self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        if self.downsample 
is not None:\n            out = self.bottleneck_sub(x)\n            residual = self.downsample(x)\n            out += residual\n        else:\n            out = self.revblock(x)\n        return out\n\n\nclass BottleneckSub(nn.Module):\n    def __init__(self, inplanes, planes, stride=1, noactivation=False):\n        super(BottleneckSub, self).__init__()\n        self.noactivation = noactivation\n        if not self.noactivation:\n            self.bn1 = batch_norm(inplanes)\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn2 = batch_norm(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn3 = batch_norm(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        if not self.noactivation:\n            x = self.bn1(x)\n            x = self.relu(x)\n        x = self.conv1(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        x = self.bn3(x)\n        x = self.relu(x)\n        x = self.conv3(x)\n        return x\n\n\nclass BasicBlockSub(nn.Module):\n    def __init__(self, inplanes, planes, stride=1, noactivation=False):\n        super(BasicBlockSub, self).__init__()\n        self.noactivation = noactivation\n        if not self.noactivation:\n            self.bn1 = batch_norm(inplanes)\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn2 = batch_norm(planes)\n        self.conv2 = conv3x3(planes, planes)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        if not self.noactivation:\n            x = self.bn1(x)\n            x = self.relu(x)\n        x = self.conv1(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n        return x\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, layers, 
num_classes=1000, channels_per_layer=None, strides=None,\n                 init_max_pool=False, init_kernel_size=7, batch_norm_fix=True, implementation=0):\n        if channels_per_layer is None:\n            channels_per_layer = [2 ** (i + 6) for i in range(len(layers))]\n            channels_per_layer = [channels_per_layer[0]] + channels_per_layer\n        if strides is None:\n            strides = [2] * len(channels_per_layer)\n        self.batch_norm_fix = batch_norm_fix\n        self.channels_per_layer = channels_per_layer\n        self.strides = strides\n        self.init_max_pool = init_max_pool\n        self.implementation = implementation\n        assert(len(self.channels_per_layer) == len(layers) + 1)  # nosec\n        self.inplanes = channels_per_layer[0]  # 64 by default\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=init_kernel_size,\n                               stride=strides[0], padding=(init_kernel_size - 1) // 2,\n                               bias=False)\n        self.bn1 = batch_norm(self.inplanes)\n        self.relu = nn.ReLU(inplace=False)\n        if self.init_max_pool:\n            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, channels_per_layer[1], layers[0], stride=strides[1], noactivation=True)\n        self.layer2 = self._make_layer(block, channels_per_layer[2], layers[1], stride=strides[2])\n        self.layer3 = self._make_layer(block, channels_per_layer[3], layers[2], stride=strides[3])\n        self.has_4_layers = len(layers) >= 4\n        if self.has_4_layers:\n            self.layer4 = self._make_layer(block, channels_per_layer[4], layers[3], stride=strides[4])\n        self.bn_final = batch_norm(self.inplanes)  # channels_per_layer[-1])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(channels_per_layer[-1] * block.expansion, num_classes)\n\n        self.configure()\n        
self.init_weights()\n\n    def init_weights(self):\n        \"\"\"Initialization using He initialization\"\"\"\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.reset_parameters()\n\n    def configure(self):\n        \"\"\"Initialization specific configuration settings\"\"\"\n        for m in self.modules():\n            if isinstance(m, InvertibleModuleWrapper):\n                m.implementation = self.implementation\n            elif isinstance(m, nn.BatchNorm2d):\n                if self.batch_norm_fix:\n                    m.momentum = 0.99\n                    m.eps = 0.001\n                else:\n                    m.momentum = 0.1\n                    m.eps = 1e-05\n\n    def _make_layer(self, block, planes, blocks, stride=1, noactivation=False):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                batch_norm(planes * block.expansion),\n            )\n        layers = [block(self.inplanes, planes, stride, downsample, noactivation)]\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        if self.init_max_pool:\n            x = self.maxpool(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        if self.has_4_layers:\n            x = self.layer4(x)\n        x = self.bn_final(x)\n        x = self.relu(x)\n        x = 
self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n        return x\n"
  },
  {
    "path": "memcnn/models/revop.py",
    "content": "import functools\nimport warnings\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom memcnn.models.additive import AdditiveCoupling\nfrom memcnn.models.affine import AffineCoupling\n\n\ntry:\n    import torch.amp\n    def custom_fwd(fwd=None, *, cast_inputs=None, device_type='cuda'):\n        if fwd is None:\n            return functools.partial(custom_fwd, cast_inputs=cast_inputs, device_type=device_type)\n        return torch.amp.custom_fwd(fwd, cast_inputs=cast_inputs, device_type=device_type)\n    def custom_bwd(bwd, device_type='cuda'):\n        return torch.amp.custom_bwd(bwd, device_type=device_type)\n\nexcept ModuleNotFoundError:\n    def custom_fwd(fwd=None, *, cast_inputs=None, device_type='cuda'):\n        if fwd is None:\n            return functools.partial(custom_fwd, cast_inputs=cast_inputs, device_type=device_type)\n        return functools.partial(fwd)\n\n    def custom_bwd(bwd, device_type='cuda'):\n        return functools.partial(bwd, device_type=device_type)\n\n\n\nclass InvertibleCheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes, preserve_rng_state, num_inputs, *inputs_and_weights):\n        # store in context\n        ctx.fn = fn\n        ctx.fn_inverse = fn_inverse\n        ctx.keep_input = keep_input\n        ctx.weights = inputs_and_weights[num_inputs:]\n        ctx.num_bwd_passes = num_bwd_passes\n        ctx.preserve_rng_state = preserve_rng_state\n        ctx.num_inputs = num_inputs\n        inputs = inputs_and_weights[:num_inputs]\n\n        if preserve_rng_state:\n            ctx.fwd_cpu_state = torch.get_rng_state()\n            # Don't eagerly initialize the cuda context by accident.\n            # (If the user intends that the context is initialized later, within their\n            # run_function, we SHOULD actually stash the cuda state here.  
Unfortunately,\n            # we have no way to anticipate this will happen before we run the function.)\n            ctx.had_cuda_in_fwd = False\n            if torch.cuda._initialized:\n                ctx.had_cuda_in_fwd = True\n                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*inputs)\n\n        ctx.input_requires_grad = [element.requires_grad for element in inputs]\n\n        with torch.no_grad():\n            # Makes a detached copy which shares the storage\n            x = [element.detach() for element in inputs]\n            outputs = ctx.fn(*x)\n\n        if not isinstance(outputs, tuple):\n            outputs = (outputs,)\n\n        # Detaches y in-place (inbetween computations can now be discarded)\n        detached_outputs = tuple([element.detach_() for element in outputs])\n\n        # clear memory from inputs\n        if not ctx.keep_input:\n            # PyTorch 1.0+ way to clear storage\n            for element in inputs:\n                element.storage().resize_(0)\n\n        # store these tensor nodes for backward pass\n        ctx.inputs = [inputs] * num_bwd_passes\n        ctx.outputs = [detached_outputs] * num_bwd_passes\n\n        return detached_outputs\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, *grad_outputs):  # pragma: no cover\n        if not torch.autograd._is_checkpoint_valid():\n            raise RuntimeError(\"InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible\")\n        # retrieve input and output tensor nodes\n        if len(ctx.outputs) == 0:\n            raise RuntimeError(\"Trying to perform backward on the InvertibleCheckpointFunction for more than \"\n                               \"{} times! 
Try raising `num_bwd_passes` by one.\".format(ctx.num_bwd_passes))\n        inputs = ctx.inputs.pop()\n        outputs = ctx.outputs.pop()\n\n        # recompute input if necessary\n        if not ctx.keep_input:\n            # Stash the surrounding rng state, and mimic the state that was\n            # present at this time during forward.  Restore the surrounding state\n            # when we're done.\n            rng_devices = []\n            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:\n                rng_devices = ctx.fwd_gpu_devices\n            with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):\n                if ctx.preserve_rng_state:\n                    torch.set_rng_state(ctx.fwd_cpu_state)\n                    if ctx.had_cuda_in_fwd:\n                        set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)\n                # recompute input\n                with torch.no_grad():\n                    inputs_inverted = ctx.fn_inverse(*outputs)\n                    if not isinstance(inputs_inverted, tuple):\n                        inputs_inverted = (inputs_inverted,)\n                    for element_original, element_inverted in zip(inputs, inputs_inverted):\n                        element_original.storage().resize_(int(np.prod(element_original.size())))\n                        element_original.set_(element_inverted)\n\n        # compute gradients\n        with torch.set_grad_enabled(True):\n            detached_inputs = tuple([element.detach().requires_grad_() for element in inputs])\n            temp_output = ctx.fn(*detached_inputs)\n        if not isinstance(temp_output, tuple):\n            temp_output = (temp_output,)\n\n        gradients = torch.autograd.grad(outputs=temp_output, inputs=detached_inputs + ctx.weights, grad_outputs=grad_outputs)\n\n        # Setting the gradients manually on the inputs and outputs (mimic backwards)\n        for element, element_grad in zip(inputs, 
gradients[:ctx.num_inputs]):\n            element.grad = element_grad\n\n        for element, element_grad in zip(outputs, grad_outputs):\n            element.grad = element_grad\n\n        return (None, None, None, None, None, None) + gradients\n\n\nclass InvertibleModuleWrapper(nn.Module):\n    def __init__(self, fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1,\n                 disable=False, preserve_rng_state=False):\n        \"\"\"\n        The InvertibleModuleWrapper which enables memory savings during training by exploiting\n        the invertible properties of the wrapped module.\n\n        Parameters\n        ----------\n            fn : :obj:`torch.nn.Module`\n                A torch.nn.Module which has a forward and an inverse function implemented with\n                :math:`x == m.inverse(m.forward(x))`\n\n            keep_input : :obj:`bool`, optional\n                Set to retain the input information on forward, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            keep_input_inverse : :obj:`bool`, optional\n                Set to retain the input information on inverse, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            num_bwd_passes :obj:`int`, optional\n                Number of backward passes to retain a link with the output. 
After the last backward pass the output\n                is discarded and memory is freed.\n                Warning: if this value is raised higher than the number of required passes memory will not be freed\n                correctly anymore and the training process can quickly run out of memory.\n                Hence, The typical use case is to keep this at 1, until it raises an error for raising this value.\n\n            disable : :obj:`bool`, optional\n                This will disable using the InvertibleCheckpointFunction altogether.\n                Essentially this renders the function as :math:`y = fn(x)` without any of the memory savings.\n                Setting this to true will also ignore the keep_input and keep_input_inverse properties.\n\n            preserve_rng_state : :obj:`bool`, optional\n                Setting this will ensure that the same RNG state is used during reconstruction of the inputs.\n                I.e. if keep_input = False on forward or keep_input_inverse = False on inverse. By default\n                this is False since most invertible modules should have a valid inverse and hence are\n                deterministic.\n\n        Attributes\n        ----------\n            keep_input : :obj:`bool`, optional\n                Set to retain the input information on forward, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            keep_input_inverse : :obj:`bool`, optional\n                Set to retain the input information on inverse, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n        Note\n        ----\n            The InvertibleModuleWrapper can be used with mixed-precision training using\n            :obj:`torch.cuda.amp.autocast` as of torch v1.6 and above. However, inputs will always be cast\n            to :obj:`torch.float32` internally. 
This is done to minimize autocasting inputs to a different datatype\n            which usually results in a disconnected computation graph and will raise an error on the backward pass.\n\n        \"\"\"\n        super(InvertibleModuleWrapper, self).__init__()\n        self.disable = disable\n        self.keep_input = keep_input\n        self.keep_input_inverse = keep_input_inverse\n        self.num_bwd_passes = num_bwd_passes\n        self.preserve_rng_state = preserve_rng_state\n        self._fn = fn\n\n    def forward(self, *xin):\n        \"\"\"Forward operation :math:`R(x) = y`\n\n        Parameters\n        ----------\n            *xin : :obj:`torch.Tensor` tuple\n                Input torch tensor(s).\n\n        Returns\n        -------\n            :obj:`torch.Tensor` tuple\n                Output torch tensor(s) *y.\n\n        \"\"\"\n        if not self.disable:\n            y = InvertibleCheckpointFunction.apply(\n                self._fn.forward,\n                self._fn.inverse,\n                self.keep_input,\n                self.num_bwd_passes,\n                self.preserve_rng_state,\n                len(xin),\n                *(xin + tuple([p for p in self._fn.parameters() if p.requires_grad])))\n        else:\n            y = self._fn(*xin)\n\n        # If the layer only has one input, we unpack the tuple again\n        if isinstance(y, tuple) and len(y) == 1:\n            return y[0]\n        return y\n\n    def inverse(self, *yin):\n        \"\"\"Inverse operation :math:`R^{-1}(y) = x`\n\n        Parameters\n        ----------\n            *yin : :obj:`torch.Tensor` tuple\n                Input torch tensor(s).\n\n        Returns\n        -------\n            :obj:`torch.Tensor` tuple\n                Output torch tensor(s) *x.\n\n        \"\"\"\n        if not self.disable:\n            x = InvertibleCheckpointFunction.apply(\n                self._fn.inverse,\n                self._fn.forward,\n                self.keep_input_inverse,\n    
            self.num_bwd_passes,\n                self.preserve_rng_state,\n                len(yin),\n                *(yin + tuple([p for p in self._fn.parameters() if p.requires_grad])))\n        else:\n            x = self._fn.inverse(*yin)\n\n        # If the layer only has one input, we unpack the tuple again\n        if isinstance(x, tuple) and len(x) == 1:\n            return x[0]\n        return x\n\n\nclass ReversibleBlock(InvertibleModuleWrapper):\n    def __init__(self, Fm, Gm=None, coupling='additive', keep_input=False, keep_input_inverse=False,\n                 implementation_fwd=-1, implementation_bwd=-1, adapter=None):\n        \"\"\"The ReversibleBlock\n\n        Warning\n        -------\n        This class has been deprecated. Use the more flexible InvertibleModuleWrapper class.\n\n        Note\n        ----\n        The `implementation_fwd` and `implementation_bwd` parameters can be set to one of the following implementations:\n\n        * -1 Naive implementation without reconstruction on the backward pass.\n        * 0  Memory efficient implementation, compute gradients directly.\n        * 1  Memory efficient implementation, similar to approach in Gomez et al. 2017.\n\n\n        Parameters\n        ----------\n            Fm : :obj:`torch.nn.Module`\n                A torch.nn.Module encapsulating an arbitrary function\n\n            Gm : :obj:`torch.nn.Module`, optional\n                A torch.nn.Module encapsulating an arbitrary function\n                (If not specified a deepcopy of Fm is used as a Module)\n\n            coupling : :obj:`str`, optional\n                Type of coupling ['additive', 'affine']. 
Default = 'additive'\n\n            keep_input : :obj:`bool`, optional\n                Set to retain the input information on forward, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            keep_input_inverse : :obj:`bool`, optional\n                Set to retain the input information on inverse, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            implementation_fwd : :obj:`int`, optional\n                Switch between different Operation implementations for forward training (Default = 1).\n                If using the naive implementation (-1) then `keep_input` should be True.\n\n            implementation_bwd : :obj:`int`, optional\n                Switch between different Operation implementations for backward training (Default = 1).\n                If using the naive implementation (-1) then `keep_input_inverse` should be True.\n\n            adapter : :obj:`class`, optional\n                Only relevant when using the 'affine' coupling.\n                Should be a class of type :obj:`torch.nn.Module` that serves as an\n                optional wrapper class A for Fm and Gm which must output\n                s, t = A(x) with shape(s) = shape(t) = shape(x).\n                s, t are respectively the scale and shift tensors for the affine coupling.\n\n        Attributes\n        ----------\n            keep_input : :obj:`bool`, optional\n                Set to retain the input information on forward, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n            keep_input_inverse : :obj:`bool`, optional\n                Set to retain the input information on inverse, by default it can be discarded since it will be\n                reconstructed upon the backward pass.\n\n        Raises\n        ------\n        NotImplementedError\n            If an unknown coupling or 
implementation is given.\n\n        \"\"\"\n        warnings.warn(\"This class has been deprecated. Use the more flexible InvertibleModuleWrapper class\", DeprecationWarning)\n        fn = create_coupling(Fm=Fm, Gm=Gm, coupling=coupling,\n                             implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd,\n                             adapter=adapter)\n        super(ReversibleBlock, self).__init__(fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse)\n\n\ndef create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):\n    if coupling == 'additive':\n        fn = AdditiveCoupling(Fm, Gm,\n                              implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)\n    elif coupling == 'affine':\n        fn = AffineCoupling(Fm, Gm, adapter=adapter,\n                            implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)\n    else:\n        raise NotImplementedError('Unknown coupling method: %s' % coupling)\n    return fn\n\n\ndef is_invertible_module(module_in, test_input_shape, test_input_dtype=torch.float32, atol=1e-6, random_seed=42):\n    \"\"\"Test if a :obj:`torch.nn.Module` is invertible\n\n    Parameters\n    ----------\n    module_in : :obj:`torch.nn.Module`\n        A torch.nn.Module to test.\n    test_input_shape : :obj:`tuple` of :obj:`int` or :obj:`tuple` of :obj:`tuple` of :obj:`int`\n        Dimensions of test tensor(s) object to perform the test with.\n    test_input_dtype : :obj:`torch.dtype`, optional\n        Data type of test tensor object to perform the test with.\n    atol : :obj:`float`, optional\n        Tolerance value used for comparing the outputs.\n    random_seed : :obj:`int`, optional\n        Use this value to seed the pseudo-random test_input_shapes with different numbers.\n\n    Returns\n    -------\n        :obj:`bool`\n            True if the input module is 
invertible, False otherwise.\n\n    \"\"\"\n    if isinstance(module_in, InvertibleModuleWrapper):\n        module_in = module_in._fn\n\n    if not hasattr(module_in, \"inverse\"):\n        return False\n\n    def _type_check_input_shape(test_input_shape):\n        if isinstance(test_input_shape, (tuple, list)):\n            if all([isinstance(e, int) for e in test_input_shape]):\n                return True\n            elif all([isinstance(e, (tuple, list)) for e in test_input_shape]):\n                return all([isinstance(ee, int) for e in test_input_shape for ee in e])\n            else:\n                return False\n        else:\n            return False\n\n    if not _type_check_input_shape(test_input_shape):\n        raise ValueError(\"test_input_shape should be of type Tuple[int, ...] or \"\n                         \"Tuple[Tuple[int, ...], ...], but {} found\".format(type(test_input_shape)))\n\n    if not isinstance(test_input_shape[0], (tuple, list)):\n        test_input_shape = (test_input_shape,)\n\n    def _check_inputs_allclose(inputs, reference, atol):\n        for inp, ref in zip(inputs, reference):\n            if not torch.allclose(inp, ref, atol=atol):\n                return False\n        return True\n\n    def _pack_if_no_tuple(x):\n        if not isinstance(x, tuple):\n            return (x, )\n        return x\n\n    with torch.no_grad():\n        torch.manual_seed(random_seed)\n        test_inputs = tuple([torch.rand(shape, dtype=test_input_dtype) for shape in test_input_shape])\n        if any([torch.equal(torch.zeros_like(e), e) for e in test_inputs]):  # pragma: no cover\n            warnings.warn(\"Some inputs were detected to be all zeros, you might want to set a different random_seed.\")\n\n        if not _check_inputs_allclose(_pack_if_no_tuple(module_in.inverse(*_pack_if_no_tuple(module_in(*test_inputs)))), test_inputs, atol=atol):\n            return False\n\n        test_outputs = _pack_if_no_tuple(module_in(*test_inputs))\n   
     if any([torch.equal(torch.zeros_like(e), e) for e in test_outputs]):  # pragma: no cover\n            warnings.warn(\"Some outputs were detected to be all zeros, you might want to set a different random_seed.\")\n\n        if not _check_inputs_allclose(_pack_if_no_tuple(module_in(*_pack_if_no_tuple(module_in.inverse(*test_outputs)))), test_outputs, atol=atol):  # pragma: no cover\n            return False\n\n        test_reconstructed_inputs = _pack_if_no_tuple(module_in.inverse(*test_outputs))\n\n    def _test_shared(inputs, outputs, msg):\n        shared = set(inputs)\n        shared_outputs = set(outputs)\n        if len(inputs) != len(shared):  # pragma: no cover\n            warnings.warn(\"Some inputs (*x) share the same tensor, are you sure this is what you want? ({})\".format(msg))\n        if len(outputs) != len(shared_outputs):\n            warnings.warn(\"Some outputs (*y) share the same tensor, are you sure this is what you want? ({})\".format(msg))\n        if any([inp in shared for inp in shared_outputs]):\n            warnings.warn(\"Some inputs (*x) and outputs (*y) share the same tensor, this is typically not a \"\n                          \"good function to use with memcnn.InvertibleModuleWrapper as it might increase memory usage. \"\n                          \"E.g. an identity function. ({})\".format(msg))\n\n    _test_shared(test_inputs, test_outputs, msg=\"forward\")\n    _test_shared(test_reconstructed_inputs, test_outputs, msg=\"inverse\")\n\n    return True\n\n\n# We can't know if the run_fn will internally move some args to different devices,\n# which would require logic to preserve rng states for those devices as well.\n# We could paranoically stash and restore ALL the rng states for all visible devices,\n# but that seems very wasteful for most cases.  
Compromise:  Stash the RNG state for\n# the device of all Tensor args.\n#\n# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?\n#\n# get_device_states and set_device_states cannot be imported from torch.utils.checkpoint, since it was not\n# present in older versions, so we include a copy here.\ndef get_device_states(*args):\n    # This will not error out if \"arg\" is a CPU tensor or a non-tensor type because\n    # the conditionals short-circuit.\n    fwd_gpu_devices = list(set(arg.get_device() for arg in args\n                               if isinstance(arg, torch.Tensor) and arg.is_cuda))\n\n    fwd_gpu_states = []\n    for device in fwd_gpu_devices:\n        with torch.cuda.device(device):\n            fwd_gpu_states.append(torch.cuda.get_rng_state())\n\n    return fwd_gpu_devices, fwd_gpu_states\n\n\ndef set_device_states(devices, states):\n    for device, state in zip(devices, states):\n        with torch.cuda.device(device):\n            torch.cuda.set_rng_state(state)\n"
  },
  {
    "path": "memcnn/models/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/models/tests/test_amp.py",
    "content": "import pytest\nimport torch\nfrom torch import nn\nimport torch.optim as optim\n\nimport torchvision\nfrom torch.utils.checkpoint import checkpoint\nfrom torchvision.models.resnet import resnet18, BasicBlock\nimport torchvision.transforms as transforms\n\nimport memcnn\n\ntry:\n    from torch.cuda.amp import autocast, GradScaler\nexcept ModuleNotFoundError:\n    pass\n\n\nclass InvertibleBlock(nn.Module):\n    def __init__(self, block, keep_input, enabled=True):\n        super().__init__()\n        self.invertible_block = memcnn.InvertibleModuleWrapper(\n            fn=memcnn.AdditiveCoupling(block),\n            keep_input=keep_input,\n            keep_input_inverse=keep_input,\n            disable=not enabled,\n        )\n\n    def forward(self, x, inverse=False):\n        if inverse:\n            return self.invertible_block.inverse(x)\n        else:\n            return self.invertible_block(x)\n\n\nclass CheckPointBlock(nn.Module):\n    def __init__(self, block):\n        super().__init__()\n        self.invertible_module = memcnn.AdditiveCoupling(block)\n\n    def forward(self, x, inverse=False):\n        return checkpoint(self.invertible_module.forward, x)\n\n\n@pytest.mark.skipif(\n    condition=\"autocast\" not in locals(),\n    reason=\"torch.cuda.amp could not be found. 
torch version is < 1.6.\",\n)\n@pytest.mark.parametrize(\n    \"use_checkpointing, inv_enabled\", ((True, False), (False, True,), (False, False))\n)\n@pytest.mark.parametrize(\"amp_enabled\", (False, True))\ndef test_cuda_amp(tmp_path, inv_enabled, amp_enabled, use_checkpointing):\n    if not torch.cuda.is_available() and amp_enabled:\n        pytest.skip(\"This test requires a GPU to be available\")\n    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n    model = resnet18(num_classes=10)\n    transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    )\n    trainset = torchvision.datasets.CIFAR10(\n        root=tmp_path, train=True, download=True, transform=transform\n    )\n    trainloader = torch.utils.data.DataLoader(\n        trainset, batch_size=4, shuffle=True, num_workers=2\n    )\n\n    # Replace layer1\n    if not use_checkpointing:\n        model.layer1 = nn.Sequential(\n            InvertibleBlock(BasicBlock(32, 32), keep_input=False, enabled=inv_enabled),\n            InvertibleBlock(BasicBlock(32, 32), keep_input=False, enabled=inv_enabled),\n        )\n    else:\n        model.layer1 = nn.Sequential(\n            CheckPointBlock(BasicBlock(32, 32)), CheckPointBlock(BasicBlock(32, 32))\n        )\n\n    model.to(device)\n\n    criterion = nn.CrossEntropyLoss()\n    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n    scaler = GradScaler(enabled=amp_enabled)\n\n    for i, data in enumerate(trainloader):\n        inputs, labels = data\n        inputs, labels = inputs.to(device), labels.to(device)\n\n        optimizer.zero_grad()\n        with autocast(enabled=amp_enabled):\n            outputs = model(inputs)\n            loss = criterion(outputs, labels)\n            scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n        break\n"
  },
  {
    "path": "memcnn/models/tests/test_couplings.py",
    "content": "import torch\nimport torch.nn\nimport pytest\nimport copy\nimport warnings\n\nfrom memcnn import create_coupling, InvertibleModuleWrapper\nfrom memcnn.models.tests.test_revop import set_seeds\nfrom memcnn.models.tests.test_models import SubModule\nfrom memcnn.models.affine import AffineAdapterNaive, AffineBlock\nfrom memcnn.models.additive import AdditiveBlock\n\n\n@pytest.mark.parametrize('coupling', ['additive', 'affine'])\n@pytest.mark.parametrize('bwd', [False, True])\n@pytest.mark.parametrize('implementation', [-1, 0, 1])\ndef test_coupling_implementations_against_reference(coupling, bwd, implementation):\n    \"\"\"Test if similar gradients and weights results are obtained after similar training for the couplings\"\"\"\n    with warnings.catch_warnings():\n        warnings.simplefilter(action='ignore', category=DeprecationWarning)\n        for seed in range(10):\n            set_seeds(seed)\n\n            X = torch.rand(2, 4, 5, 5)\n\n            # define models and their copies\n            c1 = torch.nn.Conv2d(2, 2, 3, padding=1)\n            c2 = torch.nn.Conv2d(2, 2, 3, padding=1)\n            c1_2 = copy.deepcopy(c1)\n            c2_2 = copy.deepcopy(c2)\n\n            # are weights between models the same, but do they differ between convolutions?\n            assert torch.equal(c1.weight, c1_2.weight)\n            assert torch.equal(c2.weight, c2_2.weight)\n            assert torch.equal(c1.bias, c1_2.bias)\n            assert torch.equal(c2.bias, c2_2.bias)\n            assert not torch.equal(c1.weight, c2.weight)\n\n            # define optimizers\n            optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)\n            optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)\n            for e in [c1, c2, c1_2, c2_2]:\n                e.train()\n\n            # define an arbitrary reversible function and define graph for model 1\n            XX = 
X.detach().clone().requires_grad_()\n            coupling_fn = create_coupling(Fm=c1, Gm=c2, coupling=coupling, implementation_fwd=-1,\n                                          implementation_bwd=-1, adapter=AffineAdapterNaive)\n            Y = coupling_fn.inverse(XX) if bwd else coupling_fn.forward(XX)\n            loss = torch.mean(Y)\n\n            # define the reversible function without custom backprop and define graph for model 2\n            XX2 = X.detach().clone().requires_grad_()\n            coupling_fn2 = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling, implementation_fwd=implementation,\n                                           implementation_bwd=implementation, adapter=AffineAdapterNaive)\n            Y2 = coupling_fn2.inverse(XX2) if bwd else coupling_fn2.forward(XX2)\n            loss2 = torch.mean(Y2)\n\n            # compute gradients manually\n            grads = torch.autograd.grad(loss2, (XX2, c1_2.weight, c2_2.weight, c1_2.bias, c2_2.bias), None, retain_graph=True)\n\n            # compute gradients using backward and perform optimization model 2\n            loss2.backward()\n            optim2.step()\n\n            # gradients computed manually match those of the .backward() pass\n            assert torch.equal(c1_2.weight.grad, grads[1])\n            assert torch.equal(c2_2.weight.grad, grads[2])\n            assert torch.equal(c1_2.bias.grad, grads[3])\n            assert torch.equal(c2_2.bias.grad, grads[4])\n\n            # weights differ after training a single model?\n            assert not torch.equal(c1.weight, c1_2.weight)\n            assert not torch.equal(c2.weight, c2_2.weight)\n            assert not torch.equal(c1.bias, c1_2.bias)\n            assert not torch.equal(c2.bias, c2_2.bias)\n\n            # compute gradients and perform optimization model 1\n            loss.backward()\n            optim1.step()\n\n            # weights are approximately the same after training both models?\n            assert 
torch.allclose(c1.weight.detach(), c1_2.weight.detach())\n            assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())\n            assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())\n            assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())\n\n            # gradients are approximately the same after training both models?\n            assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())\n            assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())\n            assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())\n            assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())\n\n            fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)\n            Yout = fn.inverse(XX) if bwd else fn.forward(XX)\n            loss = torch.mean(Yout)\n            loss.backward()\n            assert XX.storage().size() > 0\n\n            fn2 = InvertibleModuleWrapper(fn=coupling_fn2, keep_input=False, keep_input_inverse=False)\n            Yout2 = fn2.inverse(XX2) if bwd else fn2.forward(XX2)\n            loss = torch.mean(Yout2)\n            loss.backward()\n            assert XX2.storage().size() > 0\n\n\ndef test_legacy_additive_coupling():\n    with warnings.catch_warnings():\n        warnings.simplefilter(action='ignore', category=DeprecationWarning)\n        AdditiveBlock(Fm=SubModule())\n\n\ndef test_legacy_affine_coupling():\n    with warnings.catch_warnings():\n        warnings.simplefilter(action='ignore', category=DeprecationWarning)\n        AffineBlock(Fm=SubModule())\n"
  },
  {
    "path": "memcnn/models/tests/test_is_invertible_module.py",
    "content": "import pytest\nimport torch\n\nfrom memcnn import is_invertible_module, InvertibleModuleWrapper, AdditiveCoupling\nfrom memcnn.models.tests.test_models import IdentityInverse, MultiSharedOutputs, SubModule\n\n\ndef test_is_invertible_module_with_invalid_inverse():\n    fn = IdentityInverse(multiply_inverse=True)\n    with torch.no_grad():\n        fn.factor.zero_()\n    assert not is_invertible_module(fn, test_input_shape=(12, 12))\n\n\n@pytest.mark.parametrize(\"random_seed\", [1, 42, 900000])\ndef test_is_invertible_module_random_seeds(random_seed):\n    fn = IdentityInverse(multiply_forward=True, multiply_inverse=True)\n    assert is_invertible_module(fn, test_input_shape=(1, ), random_seed=random_seed)\n\n\ndef test_is_invertible_module_shared_outputs():\n    fnb = MultiSharedOutputs()\n    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()\n    with pytest.warns(UserWarning):\n        assert is_invertible_module(fnb, test_input_shape=(X.shape,), atol=1e-6)\n\n\ndef test_is_invertible_module_shared_tensors():\n    fn = IdentityInverse()\n    rm = InvertibleModuleWrapper(fn=fn, keep_input=True, keep_input_inverse=True)\n    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()\n    with pytest.warns(UserWarning):\n        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)\n    rm.forward(X)\n    fn.multiply_forward = True\n    rm.forward(X)\n    with pytest.warns(UserWarning):\n        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)\n    rm.inverse(X)\n    fn.multiply_inverse = True\n    rm.inverse(X)\n    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)\n\n\ndef test_is_invertible_module():\n    X = torch.zeros(1, 10, 10, 10)\n    assert not is_invertible_module(torch.nn.Conv2d(10, 10, kernel_size=(1, 1)),\n                                    test_input_shape=X.shape)\n    fn = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)\n    
assert is_invertible_module(fn, test_input_shape=X.shape)\n    class FakeInverse(torch.nn.Module):\n        def forward(self, x):\n            return x * 4\n\n        def inverse(self, y):\n            return y * 8\n    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)\n\n\ndef test_is_invertible_module_wrapped():\n    X = torch.zeros(1, 10, 10, 10)\n    assert not is_invertible_module(InvertibleModuleWrapper(torch.nn.Conv2d(10, 10, kernel_size=(1, 1))),\n                                    test_input_shape=X.shape)\n    fn = InvertibleModuleWrapper(AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1))\n    assert is_invertible_module(fn, test_input_shape=X.shape)\n    class FakeInverse(torch.nn.Module):\n        def forward(self, x):\n            return x * 4\n\n        def inverse(self, y):\n            return y * 8\n    assert not is_invertible_module(InvertibleModuleWrapper(FakeInverse()), test_input_shape=X.shape)\n\n\n@pytest.mark.parametrize(\"input_shape\", (\n    \"string\",\n    (2.3, 1.4),\n    None,\n    True,\n    ((1, 3, ), (12.4)),\n    ((1, 3, ), False)\n))\ndef test_is_invertible_module_type_check_input_shapes(input_shape):\n    with pytest.raises(ValueError):\n        is_invertible_module(module_in=IdentityInverse(multiply_forward=True, multiply_inverse=True), test_input_shape=input_shape)\n"
  },
  {
    "path": "memcnn/models/tests/test_memory_saving.py",
    "content": "import pytest\nimport gc\nimport numpy as np\nimport torch\nimport torch.nn\nfrom memcnn.models.tests.test_models import SubModule, SubModuleStack\n\n\n@pytest.mark.parametrize('coupling', ['additive', 'affine'])\n@pytest.mark.parametrize('keep_input', [True, False])\n@pytest.mark.parametrize('device', ['cpu', 'cuda'])\ndef test_memory_saving_invertible_model_wrapper(device, coupling, keep_input):\n    \"\"\"Test memory saving of the invertible model wrapper\n\n    * tests fitting a large number of images by creating a deep network requiring large\n      intermediate feature maps for training\n\n    * keep_input = False should use less memory than keep_input = True on both GPU and CPU RAM\n\n    * input size in bytes:            np.prod((2, 10, 10, 10)) * 4 / 1024.0 =  7.8125 kB\n      for a depth=5 this yields                                  7.8125 * 5 = 39.0625 kB\n\n    \"\"\"\n\n    if device == 'cpu':\n        pytest.skip('Unreliable metrics, should be fixed.')\n\n    if device == 'cuda' and not torch.cuda.is_available():\n        pytest.skip('This test requires a GPU to be available')\n\n    gc.disable()\n    gc.collect()\n\n    with torch.set_grad_enabled(True):\n        dims = [2, 10, 10, 10]\n        depth = 5\n\n        xx = torch.rand(*dims, device=device, dtype=torch.float32).requires_grad_()\n        ytarget = torch.rand(*dims, device=device, dtype=torch.float32)\n\n        # same convolution test\n        network = SubModuleStack(SubModule(in_filters=5, out_filters=5), depth=depth, keep_input=keep_input, coupling=coupling,\n                                 implementation_fwd=-1, implementation_bwd=-1)\n        network.to(device)\n        network.train()\n        network.zero_grad()\n        optim = torch.optim.RMSprop(network.parameters())\n        optim.zero_grad()\n        mem_start = 0 if not device == 'cuda' else \\\n            torch.cuda.memory_allocated() / float(1024 ** 2)\n\n        y = network(xx)\n        gc.collect()\n     
   mem_after_forward = torch.cuda.memory_allocated() / float(1024 ** 2)\n        loss = torch.nn.MSELoss()(y, ytarget)\n        optim.zero_grad()\n        loss.backward()\n        optim.step()\n        gc.collect()\n        # mem_after_backward = torch.cuda.memory_allocated() / float(1024 ** 2)\n        gc.enable()\n\n        memuse = float(np.prod(dims + [depth, 4, ])) / float(1024 ** 2)\n\n        measured_memuse = mem_after_forward - mem_start\n        if keep_input:\n            assert measured_memuse >= memuse\n        else:\n            assert measured_memuse < 1\n        # assert math.floor(mem_after_backward - mem_start) >= 9\n"
  },
  {
    "path": "memcnn/models/tests/test_models.py",
    "content": "import torch\nimport torch.nn\n\nfrom memcnn import create_coupling, InvertibleModuleWrapper\n\n\nclass MultiplicationInverse(torch.nn.Module):\n    def __init__(self, factor=2):\n        super(MultiplicationInverse, self).__init__()\n        self.factor = torch.nn.Parameter(torch.ones(1) * factor)\n\n    def forward(self, x):\n        return x * self.factor\n\n    def inverse(self, y):\n        return y / self.factor\n\n\nclass IdentityInverse(torch.nn.Module):\n    def __init__(self, multiply_forward=False, multiply_inverse=False):\n        super(IdentityInverse, self).__init__()\n        self.factor = torch.nn.Parameter(torch.ones(1))\n        self.multiply_forward = multiply_forward\n        self.multiply_inverse = multiply_inverse\n\n    def forward(self, x):\n        if self.multiply_forward:\n            return x * self.factor\n        else:\n            return x\n\n    def inverse(self, y):\n        if self.multiply_inverse:\n            return y * self.factor\n        else:\n            return y\n\n\nclass MultiSharedOutputs(torch.nn.Module):\n    # pylint: disable=R0201\n    def forward(self, x):\n        y = x * x\n        return y, y\n\n    # pylint: disable=R0201\n    def inverse(self, y, y2):\n        x = torch.max(torch.sqrt(y), torch.sqrt(y2))\n        return x\n\n\nclass SubModule(torch.nn.Module):\n    def __init__(self, in_filters=5, out_filters=5):\n        super(SubModule, self).__init__()\n        self.bn = torch.nn.BatchNorm2d(out_filters)\n        self.conv = torch.nn.Conv2d(in_filters, out_filters, (3, 3), padding=1)\n\n    def forward(self, x):\n        return self.bn(self.conv(x))\n\n\nclass SubModuleStack(torch.nn.Module):\n    def __init__(self, Gm, coupling='additive', depth=10, implementation_fwd=-1, implementation_bwd=-1,\n                 keep_input=False, adapter=None, num_bwd_passes=1):\n        super(SubModuleStack, self).__init__()\n        fn = create_coupling(Fm=Gm, Gm=Gm, coupling=coupling, 
implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd, adapter=adapter)\n        self.stack = torch.nn.ModuleList(\n            [InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input, num_bwd_passes=num_bwd_passes) for _ in range(depth)]\n        )\n\n    def forward(self, x):\n        for rev_module in self.stack:\n            x = rev_module.forward(x)\n        return x\n\n    def inverse(self, y):\n        for rev_module in reversed(self.stack):\n            y = rev_module.inverse(y)\n        return y\n\n\nclass SplitChannels(torch.nn.Module):\n    def __init__(self, split_location):\n        self.split_location = split_location\n        super(SplitChannels, self).__init__()\n\n    def forward(self, x):\n        return (x[:, :self.split_location, :].clone(),\n                x[:, self.split_location:, :].clone())\n\n    # pylint: disable=R0201\n    def inverse(self, x, y):\n        return torch.cat([x, y], dim=1)\n\n\nclass ConcatenateChannels(torch.nn.Module):\n    def __init__(self, split_location):\n        self.split_location = split_location\n        super(ConcatenateChannels, self).__init__()\n\n    # pylint: disable=R0201\n    def forward(self, x, y):\n        return torch.cat([x, y], dim=1)\n\n    def inverse(self, x):\n        return (x[:, :self.split_location, :].clone(),\n                x[:, self.split_location:, :].clone())\n"
  },
  {
    "path": "memcnn/models/tests/test_multi.py",
    "content": "import pytest\nimport torch\nfrom memcnn.models.revop import InvertibleModuleWrapper, is_invertible_module\nfrom memcnn.models.tests.test_models import SplitChannels, ConcatenateChannels\n\n\n@pytest.mark.parametrize('disable', [True, False])\ndef test_multi(disable):\n    split = InvertibleModuleWrapper(SplitChannels(2), disable = disable)\n    concat = InvertibleModuleWrapper(ConcatenateChannels(2), disable = disable)\n\n    assert is_invertible_module(split, test_input_shape=(1, 3, 32, 32))\n    assert is_invertible_module(concat, test_input_shape=((1, 2, 32, 32), (1, 1, 32, 32)))\n\n    conv_a = torch.nn.Conv2d(2, 2, 3)\n    conv_b = torch.nn.Conv2d(1, 1, 3)\n\n    x = torch.rand(1, 3, 32, 32)\n    x.requires_grad = True\n\n    a, b = split(x)\n    a, b = conv_a(a), conv_b(b)\n    y = concat(a, b)\n    loss = torch.sum(y)\n    loss.backward()\n"
  },
  {
    "path": "memcnn/models/tests/test_resnet.py",
    "content": "import pytest\nimport torch\nfrom memcnn.models.resnet import ResNet, BasicBlock, Bottleneck, RevBasicBlock, RevBottleneck\n\n\n@pytest.mark.parametrize('block,batch_norm_fix', [(BasicBlock, True), (Bottleneck, False), (RevBasicBlock, False), (RevBottleneck, True)])\ndef test_resnet(block, batch_norm_fix):\n    model = ResNet(block, [2, 2, 2, 2], num_classes=2, channels_per_layer=None,\n                   init_max_pool=True, batch_norm_fix=batch_norm_fix, strides=None)\n    model.eval()\n    with torch.no_grad():\n        x = torch.ones(2, 3, 32, 32)\n        model.forward(x)\n"
  },
  {
    "path": "memcnn/models/tests/test_revop.py",
    "content": "import warnings\nimport pytest\nimport random\nimport torch\nimport torch.nn\nimport numpy as np\nimport copy\nfrom memcnn.models.affine import AffineAdapterNaive, AffineAdapterSigmoid, AffineCoupling\nfrom memcnn.models.revop import InvertibleModuleWrapper, ReversibleBlock, create_coupling, \\\n     is_invertible_module, get_device_states, set_device_states\nfrom memcnn.models.additive import AdditiveCoupling\nfrom memcnn.models.tests.test_models import MultiplicationInverse, SubModule, SubModuleStack\n\n\ndef set_seeds(seed):\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n\n\ndef is_memory_cleared(var, isclear, shape):\n    if isclear:\n        return var.storage().size() == 0\n    else:\n        return var.storage().size() > 0 and var.shape == shape\n\n\n@pytest.mark.parametrize('device', ['cpu', 'cuda'])\n@pytest.mark.parametrize('enabled', [True, False])\ndef test_get_set_device_states(device, enabled):\n    shape = (1, 1, 10, 10)\n    if not torch.cuda.is_available() and device == 'cuda':\n        pytest.skip('This test requires a GPU to be available')\n    X = torch.ones(shape, device=device)\n    devices, states = get_device_states(X)\n    assert len(states) == (1 if device == 'cuda' else 0)\n    assert len(devices) == (1 if device == 'cuda' else 0)\n    cpu_rng_state = torch.get_rng_state()\n    Y = X * torch.rand(shape, device=device)\n    with torch.random.fork_rng(devices=devices, enabled=True):\n        if enabled:\n            if device == 'cpu':\n                torch.set_rng_state(cpu_rng_state)\n            else:\n                set_device_states(devices=devices, states=states)\n        Y2 = X * torch.rand(shape, device=device)\n    assert torch.equal(Y, Y2) == enabled\n\n\n@pytest.mark.parametrize('coupling', ['additive', 'affine'])\ndef 
test_reversible_block_notimplemented(coupling):\n    fm = torch.nn.Conv2d(10, 10, (3, 3), padding=1)\n    X = torch.zeros(1, 20, 10, 10)\n    with pytest.raises(NotImplementedError):\n        with warnings.catch_warnings():\n            warnings.simplefilter(action='ignore', category=DeprecationWarning)\n            f = ReversibleBlock(fm, coupling=coupling, implementation_bwd=0, implementation_fwd=-2,\n                                      adapter=AffineAdapterNaive)\n            assert isinstance(f, InvertibleModuleWrapper)\n            f.forward(X)\n    with pytest.raises(NotImplementedError):\n        with warnings.catch_warnings():\n            warnings.simplefilter(action='ignore', category=DeprecationWarning)\n            f = ReversibleBlock(fm, coupling=coupling, implementation_bwd=-2, implementation_fwd=0,\n                                      adapter=AffineAdapterNaive)\n            assert isinstance(f, InvertibleModuleWrapper)\n            f.inverse(X)\n    with pytest.raises(NotImplementedError):\n        with warnings.catch_warnings():\n            warnings.simplefilter(action='ignore', category=DeprecationWarning)\n            ReversibleBlock(fm, coupling='unknown', implementation_bwd=-2, implementation_fwd=0,\n                                  adapter=AffineAdapterNaive)\n\n\n@pytest.mark.parametrize('fn', [\n    AdditiveCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),\n    AffineCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive),\n    AffineCoupling(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterSigmoid),\n    MultiplicationInverse()\n])\n@pytest.mark.parametrize('bwd', [False, True])\n@pytest.mark.parametrize('keep_input', [False, True])\n@pytest.mark.parametrize('keep_input_inverse', [False, True])\n@pytest.mark.parametrize('preserve_rng_state', [False, True])\ndef test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, 
keep_input_inverse, preserve_rng_state):\n    \"\"\"InvertibleModuleWrapper tests for the memory saving forward and backward passes\n\n    * test inversion Y = RB(X) and X = RB.inverse(Y)\n    * test training the block for a single step and compare weights for implementations: 0, 1\n    * test automatic discard of input X and its retrieval after the backward pass\n    * test usage of BN to identify non-contiguous memory blocks\n\n    \"\"\"\n    for seed in range(10):\n        set_seeds(seed)\n        dims = (2, 10, 8, 8)\n        data = torch.rand(*dims, dtype=torch.float32)\n        target_data = torch.rand(*dims, dtype=torch.float32)\n\n        assert is_invertible_module(fn, test_input_shape=data.shape, atol=1e-4)\n\n        # test with zero padded convolution\n        with torch.set_grad_enabled(True):\n            X = data.clone().requires_grad_()\n\n            Ytarget = target_data.clone()\n\n            Xshape = X.shape\n\n            rb = InvertibleModuleWrapper(fn=fn, keep_input=keep_input,\n                                         keep_input_inverse=keep_input_inverse,\n                                         preserve_rng_state=preserve_rng_state)\n            s_grad = [p.detach().clone() for p in rb.parameters()]\n\n            rb.train()\n            rb.zero_grad()\n\n            optim = torch.optim.RMSprop(rb.parameters())\n            optim.zero_grad()\n            if not bwd:\n                Xin = X.clone().requires_grad_()\n                Y = rb(Xin)\n                Yrev = Y.detach().clone().requires_grad_()\n                Xinv = rb.inverse(Yrev)\n            else:\n                Xin = X.clone().requires_grad_()\n                Y = rb.inverse(Xin)\n                Yrev = Y.detach().clone().requires_grad_()\n                Xinv = rb(Yrev)\n            loss = torch.nn.MSELoss()(Y, Ytarget)\n\n            # has input been retained/discarded after forward (and backward) passes?\n\n            if not bwd:\n                assert 
is_memory_cleared(Yrev, not keep_input_inverse, Xshape)\n                assert is_memory_cleared(Xin, not keep_input, Xshape)\n            else:\n                assert is_memory_cleared(Xin, not keep_input_inverse, Xshape)\n                assert is_memory_cleared(Yrev, not keep_input, Xshape)\n\n            optim.zero_grad()\n\n            loss.backward()\n            optim.step()\n\n            assert Y.shape == Xshape\n            assert X.detach().shape == data.shape\n            assert torch.allclose(X.detach(), data, atol=1e-06)\n            assert torch.allclose(X.detach(), Xinv.detach(), atol=1e-04)  # Model is now trained and will differ\n            grads = [p.detach().clone() for p in rb.parameters()]\n\n            assert not torch.allclose(grads[0], s_grad[0])\n\n\n@pytest.mark.parametrize('coupling,adapter', [('additive', None),\n                                              ('affine', AffineAdapterNaive),\n                                              ('affine', AffineAdapterSigmoid)])\ndef test_chained_invertible_module_wrapper(coupling, adapter):\n    set_seeds(42)\n    dims = (2, 10, 8, 8)\n    data = torch.rand(*dims, dtype=torch.float32)\n    target_data = torch.rand(*dims, dtype=torch.float32)\n    with torch.set_grad_enabled(True):\n        X = data.clone().requires_grad_()\n        Ytarget = target_data.clone()\n\n        Gm = SubModule(in_filters=5, out_filters=5 if coupling == 'additive' or adapter is AffineAdapterNaive else 10)\n        rb = SubModuleStack(Gm, coupling=coupling, depth=2, keep_input=False, adapter=adapter, implementation_bwd=-1, implementation_fwd=-1)\n        rb.train()\n        optim = torch.optim.RMSprop(rb.parameters())\n\n        rb.zero_grad()\n\n        optim.zero_grad()\n\n        Xin = X.clone()\n        Y = rb(Xin)\n\n        loss = torch.nn.MSELoss()(Y, Ytarget)\n\n        loss.backward()\n        optim.step()\n\n    assert not torch.isnan(loss)\n\n\ndef 
test_chained_invertible_module_wrapper_shared_fwd_and_bwd_train_passes():\n    set_seeds(42)\n    Gm = SubModule(in_filters=5, out_filters=5)\n    rb_temp = SubModuleStack(Gm=Gm, coupling='additive', depth=5, keep_input=True, adapter=None, implementation_bwd=-1,\n                             implementation_fwd=-1)\n    optim = torch.optim.SGD(rb_temp.parameters(), lr=0.01)\n\n    initial_params = [p.detach().clone() for p in rb_temp.parameters()]\n    initial_state = copy.deepcopy(rb_temp.state_dict())\n    initial_optim_state = copy.deepcopy(optim.state_dict())\n\n    dims = (2, 10, 8, 8)\n    data = torch.rand(*dims, dtype=torch.float32)\n    target_data = torch.rand(*dims, dtype=torch.float32)\n\n    forward_outputs = []\n    inverse_outputs = []\n    for i in range(10):\n\n        is_forward_pass = i % 2 == 0\n        set_seeds(42)\n        rb = SubModuleStack(Gm=Gm, coupling='additive', depth=5, keep_input=True,\n                            adapter=None, implementation_bwd=-1,\n                            implementation_fwd=-1, num_bwd_passes=2)\n        rb.train()\n        with torch.no_grad():\n            for (name, p), p_initial in zip(rb.named_parameters(), initial_params):\n                p.set_(p_initial)\n\n        rb.load_state_dict(initial_state)\n        optim = torch.optim.SGD(rb_temp.parameters(), lr=0.01)\n        optim.load_state_dict(initial_optim_state)\n\n        with torch.set_grad_enabled(True):\n            X = data.detach().clone().requires_grad_()\n            Ytarget = target_data.detach().clone()\n\n            optim.zero_grad()\n\n            if is_forward_pass:\n                Y = rb(X)\n                Xinv = rb.inverse(Y)\n                Xinv2 = rb.inverse(Y)\n                Xinv3 = rb.inverse(Y)\n            else:\n                Y = rb.inverse(X)\n                Xinv = rb(Y)\n                Xinv2 = rb(Y)\n                Xinv3 = rb(Y)\n\n            for item in [Xinv, Xinv2, Xinv3]:\n                assert 
torch.allclose(X, item, atol=1e-04)\n\n            loss = torch.nn.MSELoss()(Xinv, Ytarget)\n            assert not torch.isnan(loss)\n\n            assert Xinv2.grad is None\n            assert Xinv3.grad is None\n\n            loss.backward()\n\n            assert Y.grad is not None\n            assert Xinv.grad is not None\n            assert Xinv2.grad is None\n            assert Xinv3.grad is None\n\n            loss2 = torch.nn.MSELoss()(Xinv2, Ytarget)\n            assert not torch.isnan(loss2)\n\n            loss2.backward()\n\n            assert Xinv2.grad is not None\n\n            optim.step()\n\n            if is_forward_pass:\n                forward_outputs.append(Y.detach().clone())\n            else:\n                inverse_outputs.append(Y.detach().clone())\n\n    for i in range(4):\n        assert torch.allclose(forward_outputs[-1], forward_outputs[i], atol=1e-06)\n        assert torch.allclose(inverse_outputs[-1], inverse_outputs[i], atol=1e-06)\n\n\n@pytest.mark.parametrize(\"inverted\", [False, True])\ndef test_invertible_module_wrapper_disabled_versus_enabled(inverted):\n    set_seeds(42)\n    Gm = SubModule(in_filters=5, out_filters=5)\n\n    coupling_fn = create_coupling(Fm=Gm, Gm=Gm, coupling='additive', implementation_fwd=-1,\n                                  implementation_bwd=-1)\n    rb = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)\n    rb2 = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn), keep_input=False, keep_input_inverse=False)\n    rb.eval()\n    rb2.eval()\n    rb2.disable = True\n    with torch.no_grad():\n        dims = (2, 10, 8, 8)\n        data = torch.rand(*dims, dtype=torch.float32)\n        X, X2 = data.clone().detach().requires_grad_(), data.clone().detach().requires_grad_()\n        if not inverted:\n            Y = rb(X)\n            Y2 = rb2(X2)\n        else:\n            Y = rb.inverse(X)\n            Y2 = rb2.inverse(X2)\n\n        assert torch.allclose(Y, Y2)\n\n    
    assert is_memory_cleared(X, True, dims)\n        assert is_memory_cleared(X2, False, dims)\n\n\n@pytest.mark.parametrize('coupling', ['additive', 'affine'])\ndef test_invertible_module_wrapper_simple_inverse(coupling):\n    \"\"\"InvertibleModuleWrapper inverse test\"\"\"\n    for seed in range(10):\n        set_seeds(seed)\n        # define some data\n        X = torch.rand(2, 4, 5, 5).requires_grad_()\n\n        # define an arbitrary reversible function\n        coupling_fn = create_coupling(Fm=torch.nn.Conv2d(2, 2, 3, padding=1), coupling=coupling, implementation_fwd=-1,\n                                      implementation_bwd=-1, adapter=AffineAdapterNaive)\n        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)\n\n        # compute output\n        Y = fn.forward(X.clone())\n\n        # compute input from output\n        X2 = fn.inverse(Y)\n\n        # check that the inverted output and the original input are approximately similar\n        assert torch.allclose(X2.detach(), X.detach(), atol=1e-06)\n\n\n@pytest.mark.parametrize('coupling', ['additive', 'affine'])\ndef test_normal_vs_invertible_module_wrapper(coupling):\n    \"\"\"InvertibleModuleWrapper test if similar gradients and weights results are obtained after similar training\"\"\"\n    for seed in range(10):\n        set_seeds(seed)\n\n        X = torch.rand(2, 4, 5, 5)\n\n        # define models and their copies\n        c1 = torch.nn.Conv2d(2, 2, 3, padding=1)\n        c2 = torch.nn.Conv2d(2, 2, 3, padding=1)\n        c1_2 = copy.deepcopy(c1)\n        c2_2 = copy.deepcopy(c2)\n\n        # are weights between models the same, but do they differ between convolutions?\n        assert torch.equal(c1.weight, c1_2.weight)\n        assert torch.equal(c2.weight, c2_2.weight)\n        assert torch.equal(c1.bias, c1_2.bias)\n        assert torch.equal(c2.bias, c2_2.bias)\n        assert not torch.equal(c1.weight, c2.weight)\n\n        # define optimizers\n        
optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)\n        optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)\n        for e in [c1, c2, c1_2, c2_2]:\n            e.train()\n\n        # define an arbitrary reversible function and define graph for model 1\n        Xin = X.clone().requires_grad_()\n        coupling_fn = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling, implementation_fwd=-1,\n                                      implementation_bwd=-1, adapter=AffineAdapterNaive)\n        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)\n\n        Y = fn.forward(Xin)\n        loss2 = torch.mean(Y)\n\n        # define the reversible function without custom backprop and define graph for model 2\n        XX = X.clone().detach().requires_grad_()\n        x1, x2 = torch.chunk(XX, 2, dim=1)\n        if coupling == 'additive':\n            y1 = x1 + c1.forward(x2)\n            y2 = x2 + c2.forward(y1)\n        elif coupling == 'affine':\n            fmr2 = c1.forward(x2)\n            fmr1 = torch.exp(fmr2)\n            y1 = (x1 * fmr1) + fmr2\n            gmr2 = c2.forward(y1)\n            gmr1 = torch.exp(gmr2)\n            y2 = (x2 * gmr1) + gmr2\n        else:\n            raise NotImplementedError()\n        YY = torch.cat([y1, y2], dim=1)\n\n        loss = torch.mean(YY)\n\n        # compute gradients manually\n        grads = torch.autograd.grad(loss, (XX, c1.weight, c2.weight, c1.bias, c2.bias), None, retain_graph=True)\n\n        # compute gradients and perform optimization model 2\n        loss.backward()\n        optim1.step()\n\n        # gradients computed manually match those of the .backward() pass\n        assert torch.equal(c1.weight.grad, grads[1])\n        assert torch.equal(c2.weight.grad, grads[2])\n        assert torch.equal(c1.bias.grad, grads[3])\n        assert torch.equal(c2.bias.grad, grads[4])\n\n        # weights differ 
after training a single model?\n        assert not torch.equal(c1.weight, c1_2.weight)\n        assert not torch.equal(c2.weight, c2_2.weight)\n        assert not torch.equal(c1.bias, c1_2.bias)\n        assert not torch.equal(c2.bias, c2_2.bias)\n\n        # compute gradients and perform optimization model 1\n        loss2.backward()\n        optim2.step()\n\n        # input is contiguous tests\n        assert Xin.is_contiguous()\n        assert Y.is_contiguous()\n\n        # weights are approximately the same after training both models?\n        assert torch.allclose(c1.weight.detach(), c1_2.weight.detach())\n        assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())\n        assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())\n        assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())\n\n        # gradients are approximately the same after training both models?\n        assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())\n        assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())\n        assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())\n        assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())\n"
  },
  {
    "path": "memcnn/models/tests/test_split_dim.py",
    "content": "\nimport pytest\nimport torch\n\nfrom memcnn import AdditiveCoupling, AffineAdapterNaive, AffineCoupling\n\n\nclass Check(torch.nn.Module):\n    def __init__(self, dim, target_size):\n        super(Check, self).__init__()\n        self.dim = dim\n        self.target_size = target_size\n\n    def forward(self, fn_input):\n        assert fn_input.size(self.dim) == self.target_size\n        return fn_input\n\n\n@pytest.mark.parametrize('dimension', [None, 0, 1, 2])\n@pytest.mark.parametrize('coupling', [AdditiveCoupling, AffineCoupling])\n@pytest.mark.parametrize('input_size', [(2, 2, 2), (2, 4, 8, 12)])\ndef test_split_dim(dimension, coupling, input_size):\n    dim = 1 if dimension is None else dimension\n    module = Check(dim, input_size[dim] // 2)\n    coupling_args = dict(adapter=AffineAdapterNaive) if coupling.__name__ == 'AffineCoupling' else dict()\n    if dimension is not None:\n        coupling_args[\"split_dim\"] = dimension\n    model = coupling(module, **coupling_args)\n    inp = torch.randn(input_size, requires_grad=False)\n    output = model(inp)\n    assert inp.shape == output.shape\n"
  },
  {
    "path": "memcnn/train.py",
    "content": "import argparse\nimport os\nimport logging\nimport torch\n\nfrom memcnn.config import Config\nfrom memcnn.experiment.manager import ExperimentManager\nfrom memcnn.experiment.factory import load_experiment_config, experiment_config_parser\n\nimport memcnn.utils.log\n\n\nlogger = logging.getLogger('train')\n\n\ndef run_experiment(experiment_tags, data_dir, results_dir, start_fresh=False, use_cuda=False, workers=None,\n                   experiments_file=None, *args, **kwargs):\n    if not os.path.exists(data_dir):\n        raise RuntimeError('Cannot find data_dir directory: {}'.format(data_dir))\n\n    if not os.path.exists(results_dir):\n        raise RuntimeError('Cannot find results_dir directory: {}'.format(results_dir))\n\n    cfg = load_experiment_config(experiments_file, experiment_tags)\n    logger.info(cfg)\n\n    model, optimizer, trainer, trainer_params = experiment_config_parser(cfg, workers=workers, data_dir=data_dir)\n\n    experiment_dir = os.path.join(results_dir, '_'.join(experiment_tags))\n    manager = ExperimentManager(experiment_dir, model, optimizer)\n    if start_fresh:\n        logger.info('Starting fresh option enabled. 
Clearing all previous results...')\n        manager.delete_dirs()\n    manager.make_dirs()\n\n    if use_cuda:\n        manager.model = manager.model.cuda()\n        import torch.backends.cudnn as cudnn\n        cudnn.benchmark = True\n\n    last_iter = manager.get_last_model_iteration()\n    if last_iter > 0:\n        logger.info('Continue experiment from iteration: {}'.format(last_iter))\n        manager.load_train_state(last_iter)\n\n    trainer_params.update(kwargs)\n\n    trainer(manager, start_iter=last_iter, use_cuda=use_cuda, *args, **trainer_params)\n\n\ndef main(data_dir, results_dir):\n    # setup logging\n    memcnn.utils.log.setup(True)\n\n    # specify defaults for arguments\n    use_cuda = torch.cuda.is_available()\n    workers = 16\n    experiments_file = os.path.join(os.path.dirname(__file__), 'config', 'experiments.json')\n    start_fresh = False\n\n    # parse arguments\n    parser = argparse.ArgumentParser(description='Run memcnn experiments.')\n    parser.add_argument('experiment_tags', type=str, nargs='+',\n                        help='Experiment tags to run and combine from the experiment config file')\n    parser.add_argument('--workers', dest='workers', type=int, default=workers,\n                        help='Number of workers for data loading (Default: {})'.format(workers))\n    parser.add_argument('--results-dir', dest='results_dir', type=str, default=results_dir,\n                        help='Directory for storing results (Default: {})'.format(results_dir))\n    parser.add_argument('--data-dir', dest='data_dir', type=str, default=data_dir,\n                        help='Directory for input data (Default: {})'.format(data_dir))\n    parser.add_argument('--experiments-file', dest='experiments_file', type=str, default=experiments_file,\n                        help='Experiments file (Default: {})'.format(experiments_file))\n    parser.add_argument('--fresh', dest='start_fresh', action='store_true', default=start_fresh,\n                  
      help='Start with fresh experiment, clears all previous results (Default: {})'\n                        .format(start_fresh))\n    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false', default=use_cuda,\n                        help='Always disables GPU use (Default: use when available)')\n    args = parser.parse_args()\n\n    if not use_cuda:\n        logger.warning('CUDA is not available in the current configuration!!!')\n\n    if not args.use_cuda:\n        logger.warning('CUDA is disabled!!!')\n\n    # run experiment given arguments\n    run_experiment(\n        args.experiment_tags,\n        args.data_dir,\n        args.results_dir,\n        start_fresh=args.start_fresh,\n        experiments_file=args.experiments_file,\n        use_cuda=args.use_cuda, workers=args.workers)\n\n\nif __name__ == '__main__':  # pragma: no cover\n    config_fname = Config.get_filename()\n    if not os.path.exists(config_fname) or not 'data_dir' in Config() or not 'results_dir' in Config():\n        print('The configuration file was not set correctly.\\n')\n        print('Please create a configuration file (json) at:\\n {}\\n'.format(config_fname))\n        print('The configuration file should be formatted as follows:\\n\\n'\n              '{\\n'\n              '    \"data_dir\": \"/home/user/data\",\\n'\n              '    \"results_dir\": \"/home/user/experiments\"\\n'\n              '}\\n')\n        print('data_dir    : location for storing the input training datasets')\n        print('results_dir : location for storing the experiment files during training')\n    else:\n        main(data_dir=Config()['data_dir'],\n             results_dir=Config()['results_dir'])\n"
  },
  {
    "path": "memcnn/trainers/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/trainers/classification.py",
    "content": "import time\nimport logging\nimport torch\nimport numpy as np\nfrom memcnn.utils.stats import AverageMeter, accuracy\nfrom memcnn.utils.log import SummaryWriter\n\nlogger = logging.getLogger('trainer')\n\n\ndef validate(model, ceriterion, val_loader, device):\n    \"\"\"validation sub-loop\"\"\"\n    model.eval()\n\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n\n    end = time.time()\n    with torch.no_grad():\n        for x, label in val_loader:\n            x, label = x.to(device), label.to(device)\n            vx, vl = x, label\n\n            score = model(vx)\n            loss = ceriterion(score, vl)\n            prec1 = accuracy(score.data, label)\n\n            losses.update(loss.item(), x.size(0))\n            top1.update(prec1[0][0], x.size(0))\n\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n    logger.info('Test: [{0}/{0}]\\t'\n                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'.format(len(val_loader),\n                                                                  batch_time=batch_time, loss=losses, top1=top1))\n\n    return top1.avg, losses.avg\n\n\ndef get_model_parameters_count(model):\n    return np.sum([np.prod([int(e) for e in p.shape]) for p in model.parameters()])\n\n\ndef train(manager,\n          train_loader,\n          test_loader,\n          start_iter,\n          disp_iter=100,\n          save_iter=10000,\n          valid_iter=1000,\n          use_cuda=False,\n          loss=None):\n    \"\"\"train loop\"\"\"\n\n    device = torch.device('cpu' if not use_cuda else 'cuda')\n    model, optimizer = manager.model, manager.optimizer\n\n    logger.info('Model parameters: {}'.format(get_model_parameters_count(model)))\n\n    if use_cuda:\n        model_mem_allocation = torch.cuda.memory_allocated(device)\n        
logger.info('Model memory allocation: {}'.format(model_mem_allocation))\n    else:\n        model_mem_allocation = None\n\n    writer = SummaryWriter(manager.log_dir)\n    data_time = AverageMeter()\n    batch_time = AverageMeter()\n    losses = AverageMeter()\n    top1 = AverageMeter()\n    act_mem_activations = AverageMeter()\n\n    ceriterion = loss\n    # ensure train_loader enumerates to max_epoch\n    max_iterations = train_loader.sampler.nsamples // train_loader.batch_size\n    train_loader.sampler.nsamples = train_loader.sampler.nsamples - start_iter\n    end = time.time()\n    for ind, (x, label) in enumerate(train_loader):\n        iteration = ind + 1 + start_iter\n\n        if iteration > max_iterations:\n            logger.info('maximum number of iterations reached: {}/{}'.format(iteration, max_iterations))\n            break\n\n        if iteration == 40000 or iteration == 60000:\n            for param_group in optimizer.param_groups:\n                param_group['lr'] *= 0.1\n\n        model.train()\n\n        data_time.update(time.time() - end)\n        end = time.time()\n        x, label = x.to(device), label.to(device)\n        vx, vl = x, label\n\n        score = model(vx)\n        loss = ceriterion(score, vl)\n\n        if use_cuda:\n            activation_mem_allocation = torch.cuda.memory_allocated(device) - model_mem_allocation\n            act_mem_activations.update(activation_mem_allocation, iteration)\n\n        if torch.isnan(loss):\n            raise ValueError(\"Loss became NaN during iteration {}\".format(iteration))\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        batch_time.update(time.time()-end)\n        prec1 = accuracy(score.data, label)\n\n        losses.update(loss.item(), x.size(0))\n        top1.update(prec1[0][0], x.size(0))\n\n        if iteration % disp_iter == 0:\n            act = ''\n            if model_mem_allocation is not None:\n                act = 'ActMem {act.val:.3f} 
({act.avg:.3f})'.format(act=act_mem_activations)\n            logger.info('iteration: [{0}/{1}]\\t'\n                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n                        'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\\t'\n                        '{act}'\n                        .format(iteration, max_iterations,\n                                batch_time=batch_time, data_time=data_time,\n                                loss=losses, top1=top1, act=act))\n\n        if iteration % disp_iter == 0:\n            writer.add_scalar('train_loss', loss.item(), iteration)\n            writer.add_scalar('train_acc', prec1[0][0], iteration)\n            losses.reset()\n            top1.reset()\n            data_time.reset()\n            batch_time.reset()\n            if use_cuda:\n                writer.add_scalar('act_mem_allocation', act_mem_activations.avg, iteration)\n                act_mem_activations.reset()\n\n        if iteration % valid_iter == 0:\n            test_top1, test_loss = validate(model, ceriterion, test_loader, device=device)\n            writer.add_scalar('test_loss', test_loss, iteration)\n            writer.add_scalar('test_acc', test_top1, iteration)\n\n        if iteration % save_iter == 0:\n            manager.save_train_state(iteration)\n            writer.flush()\n\n        end = time.time()\n\n    writer.close()\n"
  },
  {
    "path": "memcnn/trainers/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/trainers/tests/resources/experiments.json",
    "content": "{\n    \"testsetup\": {\n        \"model\": \"memcnn.trainers.tests.test_train.DummyModel\",\n        \"model_params\": {\n            \"block\":\"memcnn.trainers.tests.test_train.DummyDataset\"\n        },\n        \"optimizer\": \"torch.optim.SGD\",\n        \"optimizer_params\": {\n            \"lr\":0.1\n        },\n        \"trainer\": \"memcnn.trainers.tests.test_train.dummy_trainer\",\n        \"trainer_params\": {\n            \"loss\":\"memcnn.trainers.tests.test_train.DummyDataset\"\n        },\n        \"data_loader\": \"memcnn.trainers.tests.test_train.dummy_dataloaders\",\n        \"data_loader_params\":\n        {\n            \"dataset\": \"memcnn.trainers.tests.test_train.DummyDataset\",\n            \"workers\": 0\n        }\n    },\n\n    \"resnet32\":\n    {\n        \"data_loader_params\": {\n            \"batch_size\": 100,\n            \"max_epoch\": 80000\n        },\n        \"model\": \"memcnn.models.resnet.ResNet\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.BasicBlock\",\n            \"layers\":[5, 5, 5],\n            \"channels_per_layer\":[16,16,32,64],\n            \"strides\":[1, 1, 2, 2],\n            \"init_max_pool\":false,\n            \"init_kernel_size\":3,\n            \"batch_norm_fix\":false\n        },\n        \"optimizer\": \"torch.optim.SGD\",\n        \"optimizer_params\": {\n            \"lr\":0.1,\n            \"momentum\":0.9,\n            \"weight_decay\":2e-4\n        },\n        \"trainer\":\"memcnn.trainers.classification.train\",\n        \"trainer_params\":{\n            \"loss\":\"memcnn.utils.loss.CrossEntropyLossTF\"\n        }\n    },\n\n    \"resnet110\":\n    {\n        \"base\": \"resnet32\",\n        \"model_params\": {\n            \"layers\":[18, 18, 18]\n        }\n    },\n\n    \"resnet164\":\n    {\n        \"base\": \"resnet110\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.Bottleneck\"\n        }\n    },\n\n    
\"revnet38\":\n    {\n        \"base\": \"resnet32\",\n        \"model_params\": {\n            \"layers\":[3, 3, 3],\n            \"channels_per_layer\":[32,32,64,112],\n            \"block\":\"memcnn.models.resnet.RevBasicBlock\"\n        }\n    },\n\n    \"revnet110\":\n    {\n        \"base\": \"revnet38\",\n        \"model_params\": {\n            \"layers\":[9, 9, 9],\n            \"channels_per_layer\":[32,32,64,128]\n        }\n    },\n\n    \"revnet164\":\n    {\n        \"base\": \"revnet110\",\n        \"model_params\": {\n            \"block\":\"memcnn.models.resnet.RevBottleneck\"\n        }\n    },\n\n    \"cifar10\":\n    {\n        \"data_loader\": \"memcnn.data.cifar.get_cifar_data_loaders\",\n        \"data_loader_params\": {\n            \"dataset\": \"torchvision.datasets.CIFAR10\",\n            \"workers\": 16\n        },\n        \"model_params\": {\n            \"num_classes\":10\n        }\n    },\n\n    \"cifar100\":\n    {\n        \"data_loader\": \"memcnn.data.cifar.get_cifar_data_loaders\",\n        \"data_loader_params\": {\n            \"dataset\": \"torchvision.datasets.CIFAR100\",\n            \"workers\": 16\n        },\n        \"model_params\": {\n            \"num_classes\":100\n        }\n    },\n\n    \"epoch5\":\n    {\n        \"data_loader_params\": {\n            \"max_epoch\": 5\n        }\n    }\n}\n"
  },
  {
    "path": "memcnn/trainers/tests/test_classification.py",
    "content": "import pytest\n\nfrom memcnn.trainers.classification import train\nfrom memcnn.experiment.manager import ExperimentManager\nfrom memcnn.data.cifar import get_cifar_data_loaders\nfrom memcnn.utils.loss import CrossEntropyLossTF\nimport torch\nfrom torchvision.datasets.cifar import CIFAR10\n\n\nclass SimpleTestingModel(torch.nn.Module):\n    def __init__(self, klasses):\n        super(SimpleTestingModel, self).__init__()\n        self.conv = torch.nn.Conv2d(3, klasses, 1)\n        self.avgpool = torch.nn.AvgPool2d(32)\n        self.klasses = klasses\n\n    def forward(self, x):\n        return self.avgpool(self.conv(x)).reshape(x.shape[0], self.klasses)\n\n\ndef test_train(tmp_path):\n    expdir = str(tmp_path / \"testexp\")\n    tmp_data_dir = str(tmp_path / \"tmpdata\")\n    num_klasses = 10\n\n    model = SimpleTestingModel(num_klasses)\n    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)\n    manager = ExperimentManager(expdir, model, optimizer)\n    manager.make_dirs()\n\n    train_loader, test_loader = get_cifar_data_loaders(CIFAR10, tmp_data_dir, 40000, 2, 0)\n    loss = CrossEntropyLossTF()\n\n    train(manager,\n          train_loader,\n          test_loader,\n          start_iter=39999,\n          disp_iter=1,\n          save_iter=1,\n          valid_iter=1,\n          use_cuda=False,\n          loss=loss)\n\n\ndef test_train_with_nan_loss(tmp_path):\n    class NanLoss(torch.nn.Module):\n        def __init__(self):\n            super(NanLoss, self).__init__()\n\n        def forward(self, Ypred, Y, W=None):\n            return Ypred.mean() * float('nan')\n\n    expdir = str(tmp_path / \"testexp\")\n    tmp_data_dir = str(tmp_path / \"tmpdata\")\n    num_klasses = 10\n\n    model = SimpleTestingModel(num_klasses)\n    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)\n    manager = ExperimentManager(expdir, model, optimizer)\n    manager.make_dirs()\n\n    train_loader, test_loader = 
get_cifar_data_loaders(CIFAR10, tmp_data_dir, 40000, 2, 0)\n    loss = NanLoss()\n\n    with pytest.raises(ValueError) as e:\n        train(manager,\n              train_loader,\n              test_loader,\n              start_iter=1,\n              disp_iter=1,\n              save_iter=1,\n              valid_iter=1,\n              use_cuda=False,\n              loss=loss)\n    assert \"Loss became NaN during iteration\" in str(e.value)\n"
  },
  {
    "path": "memcnn/trainers/tests/test_train.py",
    "content": "import json\n\nimport pytest\nimport os\nimport sys\nimport torch\n\nfrom memcnn.experiment.manager import ExperimentManager\nfrom memcnn.train import run_experiment, main\ntry:\n    from pathlib2 import Path\nexcept ImportError:\n    from pathlib import Path\n\n\ndef test_main(tmp_path):\n    sys.argv = ['train.py', 'cifar10', 'resnet34', '--fresh', '--no-cuda', '--workers=0']\n    data_dir = str(tmp_path / \"tmpdata\")\n    results_dir = str(tmp_path / \"resdir\")\n    os.makedirs(data_dir)\n    os.makedirs(results_dir)\n    with pytest.raises(KeyError):\n        main(data_dir=data_dir, results_dir=results_dir)\n\n\ndef dummy_dataloaders(*args, **kwargs):\n    return None, None\n\n\ndef dummy_trainer(manager, *args, **kwargs):\n    manager.save_train_state(2)\n\n\nclass DummyDataset(object):\n    def __init__(self, *args, **kwargs):\n        pass\n\n\nclass DummyModel(torch.nn.Module):\n    def __init__(self, block):\n        super(DummyModel, self).__init__()\n        self.block = block\n        self.conv = torch.nn.Conv2d(1, 1, 1)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\ndef test_run_experiment(tmp_path):\n    exptags = ['testsetup']\n    exp_file = str(Path(__file__).parent / \"resources\" / \"experiments.json\")\n    data_dir = str(tmp_path / \"tmpdata\")\n    results_dir = str(tmp_path / \"resdir\")\n    run_params = dict(\n        experiment_tags=exptags, data_dir=data_dir, results_dir=results_dir,\n        start_fresh=True, use_cuda=False, workers=None, experiments_file=exp_file\n    )\n    with pytest.raises(RuntimeError):\n        run_experiment(**run_params)\n    os.makedirs(data_dir)\n    with pytest.raises(RuntimeError):\n        run_experiment(**run_params)\n    os.makedirs(results_dir)\n    run_experiment(**run_params)\n    run_params[\"start_fresh\"] = False\n    run_experiment(**run_params)\n\n\n@pytest.mark.parametrize(\"network\", [\n    pytest.param(network,\n                 marks=pytest.mark.skipif(\n     
                condition=(\"FULL_NETWORK_TESTS\" not in os.environ) and (\"revnet38\" != network),\n                     reason=\"Too memory intensive for CI so these tests are disabled by default. \"\n                            \"Set FULL_NETWORK_TESTS environment variable to enable the tests.\")\n                 )\n    for network in [\"resnet32\", \"resnet110\", \"resnet164\", \"revnet38\", \"revnet110\", \"revnet164\"]\n])\n@pytest.mark.parametrize(\"use_cuda\", [\n    False,\n    pytest.param(True, marks=pytest.mark.skipif(condition=not torch.cuda.is_available(), reason=\"No GPU available\"))\n])\ndef test_train_networks(tmp_path, network, use_cuda):\n    exptags = [\"cifar10\", network, \"epoch5\"]\n    exp_file = str(Path(__file__).parent / \"resources\" / \"experiments.json\")\n    data_dir = str(tmp_path / \"tmpdata\")\n    results_dir = str(tmp_path / \"resdir\")\n    os.makedirs(data_dir)\n    os.makedirs(results_dir)\n    run_experiment(experiment_tags=exptags, data_dir=data_dir, results_dir=results_dir,\n                   start_fresh=True, use_cuda=use_cuda, workers=None, experiments_file=exp_file,\n                   disp_iter=1,\n                   save_iter=5,\n                   valid_iter=5,)\n    experiment_dir = os.path.join(results_dir, '_'.join(exptags))\n    assert os.path.exists(experiment_dir)\n    manager = ExperimentManager(experiment_dir)\n    scalars_file = os.path.join(manager.log_dir, \"scalars.json\")\n    assert os.path.exists(scalars_file)\n    with open(scalars_file, \"r\") as f:\n        results = json.load(f)\n    # no results should hold any NaN values\n    assert not any([val != val for t, i, val in results[\"train_loss\"]])\n"
  },
  {
    "path": "memcnn/utils/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/utils/log.py",
    "content": "import os\nimport json\nimport logging\nimport sys\nimport time\n\n\ndef setup(use_stdout=True, filename=None, log_level=logging.DEBUG):\n    \"\"\"setup some basic logging\"\"\"\n\n    log = logging.getLogger('')\n    log.setLevel(log_level)\n    fmt = logging.Formatter(\"%(asctime)s [%(name)-15s] %(message)s\", datefmt=\"%y-%m-%d %H:%M:%S\")\n\n    if use_stdout:\n        ch = logging.StreamHandler(sys.stdout)\n        ch.setLevel(log_level)\n        ch.setFormatter(fmt)\n        log.addHandler(ch)\n\n    if filename is not None:\n        fh = logging.FileHandler(filename)\n        fh.setLevel(log_level)\n        fh.setFormatter(fmt)\n        log.addHandler(fh)\n\n\nclass SummaryWriter(object):\n    def __init__(self, log_dir):\n        self._log_dir = log_dir\n        self._log_file = os.path.join(log_dir, \"scalars.json\")\n        self._summary = {}\n        self._load_if_exists()\n\n    def _load_if_exists(self):\n        if os.path.exists(self._log_file):\n            with open(self._log_file, \"r\") as f:\n                self._summary = json.load(f)\n\n    def add_scalar(self, name, value, iteration):\n        if name not in self._summary:\n            self._summary[name] = []\n        self._summary[name].append([time.time(), int(iteration), float(value)])\n\n    def flush(self):\n        with open(self._log_file, \"w\") as f:\n            json.dump(self._summary, f)\n\n    def close(self):\n        self.flush()\n"
  },
  {
    "path": "memcnn/utils/loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn.modules.module import Module\n\n\ndef _assert_no_grad(variable):\n    msg = \"nn criterions don't compute the gradient w.r.t. targets - please \" \\\n          \"mark these variables as not requiring gradients\"\n    assert not variable.requires_grad, msg  # nosec\n\n\nclass CrossEntropyLossTF(Module):\n    def __init__(self):\n        super(CrossEntropyLossTF, self).__init__()\n\n    def forward(self, Ypred, Y, W=None):\n        _assert_no_grad(Y)\n        lsm = nn.Softmax(dim=1)\n        y_onehot = torch.zeros(Ypred.shape[0], Ypred.shape[1], dtype=torch.float32, device=Ypred.device)\n        y_onehot.scatter_(1, Y.data.view(-1, 1), 1)\n        if W is not None:\n            y_onehot = y_onehot * W\n        return torch.mean(-y_onehot * torch.log(lsm(Ypred))) * Ypred.shape[1]\n"
  },
  {
    "path": "memcnn/utils/stats.py",
    "content": "\"\"\" Module containing utilities to compute statistics\n\nSome bits from: https://gist.github.com/xmfbit/67c407e34cbaf56e7820f09e774e56d8\n\"\"\"\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n# top-k accuracy\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Computes the precision@k for the specified values of k\"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        correct_k = correct[:k].contiguous().view(-1).float().sum(dim=0, keepdim=True)\n        res.append(correct_k.mul_(100.0 / batch_size))\n    return res\n"
  },
  {
    "path": "memcnn/utils/tests/__init__.py",
    "content": ""
  },
  {
    "path": "memcnn/utils/tests/test_log.py",
    "content": "import logging\nfrom memcnn.utils.log import setup, SummaryWriter\n\n\ndef test_setup(tmp_path):\n    logfile = str(tmp_path / 'testlog.log')\n    setup(use_stdout=True, filename=logfile, log_level=logging.DEBUG)\n\n\ndef test_summary_writer(tmp_path):\n    logfile = tmp_path / 'scalars.json'\n\n    assert not logfile.exists()\n    writer = SummaryWriter(log_dir=str(tmp_path))\n    writer.add_scalar(\"test_value\", 0.5, 1)\n    writer.add_scalar(\"test_value\", 2.5, 2)\n    writer.add_scalar(\"test_value2\", 123, 1)\n    writer.flush()\n    assert logfile.exists()\n\n    writer = SummaryWriter(log_dir=str(tmp_path))\n\n    assert \"test_value\" in writer._summary\n    assert \"test_value2\" in writer._summary\n    assert len(writer._summary[\"test_value\"]) == 2\n\n    writer.add_scalar(\"test_value\", 123.4, 3)\n    writer.close()\n\n    writer = SummaryWriter(log_dir=str(tmp_path))\n\n    assert \"test_value\" in writer._summary\n    assert \"test_value2\" in writer._summary\n    assert len(writer._summary[\"test_value\"]) == 3\n"
  },
  {
    "path": "memcnn/utils/tests/test_loss.py",
    "content": "import torch\nfrom memcnn.utils.loss import _assert_no_grad, CrossEntropyLossTF\n\n\ndef test_assert_no_grad():\n    data = torch.ones(3, 3, 3)\n    data.requires_grad = False\n    _assert_no_grad(data)\n\n\ndef test_crossentropy_tf():\n    batch_size = 5\n    shape = (batch_size, 2)\n    loss = CrossEntropyLossTF()\n    ypred = torch.ones(*shape)\n    ypred.requires_grad = True\n    y = torch.ones(batch_size, dtype=torch.int64)\n    y.requires_grad = False\n    w = torch.ones(*shape)\n    w.requires_grad = False\n    w2 = torch.zeros(*shape)\n    w2.requires_grad = False\n\n    out1 = loss(ypred, y)\n    assert len(out1.shape) == 0\n\n    out2 = loss(ypred, y, w)\n    assert len(out2.shape) == 0\n\n    out3 = loss(ypred, y, w2)\n    assert out3 == 0\n    assert len(out3.shape) == 0\n"
  },
  {
    "path": "memcnn/utils/tests/test_stats.py",
    "content": "import pytest\nimport torch\nfrom memcnn.utils.stats import AverageMeter, accuracy\n\n\n@pytest.mark.parametrize('val,n', [(1, 1), (14, 10), (10, 14), (5, 1), (1, 5), (0, 10)])\ndef test_average_meter(val, n):\n    meter = AverageMeter()\n    assert meter.val == 0\n    assert meter.avg == 0\n    assert meter.sum == 0\n    assert meter.count == 0\n    meter.update(val, n=n)\n    assert meter.val == val\n    assert meter.avg == val\n    assert meter.sum == val * n\n    assert meter.count == n\n\n\n@pytest.mark.parametrize('topk,klass', [((1,), 4), ((1, 3,), 2), ((5,), 1)])\ndef test_accuracy(topk, klass, num_klasses=5):  # output, target,\n    batch_size = 5\n    target = torch.ones(batch_size, dtype=torch.long) * klass\n    output = torch.zeros(batch_size, num_klasses)\n    output[:, klass] = 1\n    res = accuracy(output, target, topk)\n    assert len(res) == len(topk)\n    assert all([e == 100.0 for e in res])\n"
  },
  {
    "path": "paper/README",
    "content": "The paper can be compiled locally using the following command:\n\npandoc paper.md --bibliography paper.bib -o paper_local.pdf\n"
  },
  {
    "path": "paper/paper.bib",
    "content": "@misc{Gomez17,\n    author    = {A. N. Gomez and\n               M. Ren and\n               R. Urtasun and\n               R. B. Grosse},\n    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations},\n    howpublished = {{\\tt arXiv:1707.04585 [cs.CV]}},\n    url = {https://arxiv.org/abs/1707.04585},\n    year      = 2017\n}\n\n@misc{Dinh14,\n    author    = {L. Dinh and\n               D. Krueger and\n               Y. Bengio},\n    title     = {{NICE:} Non-linear Independent Components Estimation},\n    howpublished = {{\\tt arXiv:1410.8516 [cs.LG]}},\n    url = {https://arxiv.org/abs/1410.8516},\n    year      = 2014\n}\n\n@article{He2016,\n    title = {Identity Mappings in Deep Residual Networks},\n    note = {{\\tt arXiv:1603.05027 [cs.CV]}},\n    year = 2016,\n    doi = {10.1007/978-3-319-46493-0_38},\n    publisher = {Springer International Publishing},\n    pages = {630--645},\n    author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},\n    booktitle = {Computer Vision {\\textendash} {ECCV} 2016}\n}\n\n@article{He2015,\n    note = {{\\tt arXiv:1512.03385 [cs.CV]}},\n    doi = {10.1109/cvpr.2016.90},\n    year = {2016},\n    month = jun,\n    publisher = {{IEEE}},\n    author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},\n    title = {Deep Residual Learning for Image Recognition},\n    booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition ({CVPR})}\n}\n\n@misc{Chang17,\n    author    = {B. Chang and\n               L. Meng and\n               E. Haber and\n               L. Ruthotto and\n               D. Begert and\n               E. 
Holtham},\n    title     = {Reversible Architectures for Arbitrarily Deep Residual Neural Networks},\n    howpublished = {{\\tt arXiv:1709.03698 [cs.CV]}},\n    url = {https://arxiv.org/abs/1709.03698},\n    year      = 2017\n}\n\n@mastersthesis{krizhevsky2009learning,\n    author = {Krizhevsky, A.},\n    title = {Learning Multiple Layers of Features from Tiny Images},\n    school       = {University of Toronto},\n    year         = 2009,\n    address      = {Toronto, Ontario, Canada},\n    month        = apr,\n}\n\n@inproceedings{imagenet_cvpr09,\n    doi = {10.1109/cvprw.2009.5206848},\n    year = 2009,\n    month = jun,\n    publisher = {{IEEE}},\n    author = {Jia Deng and Wei Dong and Richard Socher and Li-Jia Li and  Kai Li and  Li Fei-Fei},\n    title = {{ImageNet}: A large-scale hierarchical image database},\n    booktitle = {2009 {IEEE} Conference on Computer Vision and Pattern Recognition}\n}\n\n@inproceedings{jaco18,\n    author = {J.-H. Jacobsen and A.W.M. Smeulders and E. Oyallon},\n    title = {{i-RevNet}: Deep Invertible Networks},\n    booktitle = {ICLR},\n    year = {2018},\n    url = {https://arxiv.org/abs/1802.07088},\n    howpublished = {{\\tt arXiv:1802.07088 [cs.LG]}}\n}\n\n@misc{TF2015,\n    title={{TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems},\n    note={Software available from tensorflow.org},\n    author={\n        M.~Abadi and\n        A.~Agarwal and\n        P.~Barham and\n        E.~Brevdo and\n        Z.~Chen and\n        C.~Citro and\n        G.~S.~Corrado and\n        A.~Davis and\n        J.~Dean and\n        M.~Devin and\n        S.~Ghemawat and\n        I.~Goodfellow and\n        A.~Harp and\n        G.~Irving and\n        M.~Isard and\n        Y.~Jia and\n        R.~Jozefowicz and\n        L.~Kaiser and\n        M.~Kudlur and\n        J.~Levenberg and\n        D.~Man\\'{e} and\n        R.~Monga and\n        S.~Moore and\n        D.~Murray and\n        C.~Olah and\n        M.~Schuster and\n        J.~Shlens 
and\n        B.~Steiner and\n        I.~Sutskever and\n        K.~Talwar and\n        P.~Tucker and\n        V.~Vanhoucke and\n        V.~Vasudevan and\n        F.~Vi\\'{e}gas and\n        O.~Vinyals and\n        P.~Warden and\n        M.~Wattenberg and\n        M.~Wicke and\n        Y.~Yu and\n        X.~Zheng},\n    year={2015},\n    url={http://tensorflow.org/},\n}\n\n@inproceedings{paszke2017automatic,\n    title={Automatic differentiation in {PyTorch}},\n    author={Paszke, A. and Gross, S. and Chintala, S. and Chanan, G. and Yang, E. and DeVito, Z. and Lin, Z. and Desmaison, A. and Antiga, L. and Lerer, A.},\n    booktitle={NIPS-W},\n    year={2017},\n    howpublished = {\\url{https://openreview.net/forum?id=BJJsrmfCZ}},\n}\n\n@inproceedings{kingma2018glow,\n    title = {Glow: Generative Flow with Invertible 1x1 Convolutions},\n    author = {Kingma, Durk P and Dhariwal, Prafulla},\n    booktitle = {Advances in Neural Information Processing Systems 31},\n    editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. 
Garnett},\n    pages = {10215--10224},\n    year = {2018},\n    publisher = {Curran Associates, Inc.},\n    url = {http://papers.nips.cc/paper/8224-glow-generative-flow-with-invertible-1x1-convolutions.pdf}\n}\n\n@misc{dinh2016density,\n    title={Density estimation using {Real NVP}},\n    author={Dinh, Laurent and Sohl-Dickstein, Jascha and Bengio, Samy},\n    year={2016},\n    howpublished = {{\\tt arXiv:1605.08803 [cs.LG]}},\n    url = {https://arxiv.org/abs/1605.08803}\n}\n\n@incollection{martens2012training,  \n    doi = {10.1007/978-3-642-35289-8_27},\n    year = 2012,\n    publisher = {Springer Berlin Heidelberg},\n    pages = {479--535},\n    author = {James Martens and Ilya Sutskever},\n    title = {Training Deep and Recurrent Networks with Hessian-Free Optimization},\n    booktitle = {Neural Networks: Tricks of the Trade: Second Edition}\n}\n\n@misc{chen2016training,\n    title={Training deep nets with sublinear memory cost},\n    author={Chen, Tianqi and Xu, Bing and Zhang, Chiyuan and Guestrin, Carlos},\n    howpublished = {{\\tt arXiv:1604.06174 [cs.LG]}},\n    url = {https://arxiv.org/abs/1604.06174},\n    year=2016\n}\n\n@InProceedings{Ouderaa_2019_CVPR,\n    author = {Ouderaa, Tycho F.A. 
van der and Worrall, Daniel E.},\n    title = {Reversible {GANs} for Memory-Efficient Image-To-Image Translation},\n    booktitle = {{The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}},\n    month = jun,\n    year = 2019,\n    howpublished = {{\\tt arXiv:1902.02729 [cs.CV]}},\n    url = {https://arxiv.org/abs/1902.02729},\n}\n\n@inproceedings{zhu2017unpaired,\n    title={Unpaired image-to-image translation using cycle-consistent adversarial networks},\n    author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},\n    booktitle={{Proceedings of the IEEE International Conference on Computer Vision}},\n    pages={2223--2232},\n    year={2017},\n    doi={10.1109/iccv.2017.244},\n}\n\n@inproceedings{ouderaa:MIDLAbstract2019a,\n    title={Chest {CT} Super-resolution and Domain-adaptation using Memory-efficient 3D Reversible {GAN}s},\n    author={Tycho F.A. van der Ouderaa and Daniel E. Worrall and Bram van Ginneken},\n    booktitle={International Conference on Medical Imaging with Deep Learning},\n    address={London, United Kingdom},\n    year=2019,\n    month=jul,\n    url={https://openreview.net/forum?id=SkxueFsiFV}\n}\n"
  },
  {
    "path": "paper/paper.md",
    "content": "---\ntitle: 'MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks'\ntags:\n  - MemCNN\n  - Python\n  - PyTorch\n  - machine learning\n  - invertible networks\n  - deep learning  \nauthors:\n  - name: Sil C. van de Leemput\n    orcid: 0000-0001-6047-3051\n    affiliation: 1\n  - name: Jonas Teuwen\n    affiliation: 1\n  - name: Bram van Ginneken\n    affiliation: 1\n  - name: Rashindra Manniesing\n    affiliation: 1    \naffiliations:\n  - name: Radboud University Medical Center, Department of Radiology and Nuclear Medicine, Nijmegen, The Netherlands\n    index: 1\ndate: 28 June 2019\nbibliography: paper.bib\n---\n\n# Summary\n\nNeural networks are computational models that were originally inspired by biological neural networks like animal brains. These networks are composed of many small computational units called neurons that perform elementary calculations. Instead of explicitly programming the behavior of neural networks, these models can be trained to perform tasks, like classifying images, by presenting them examples. Sufficiently complex neural networks can automatically extract task-relevant characteristics from the presented examples without having prior knowledge about the task domain, which makes them attractive for many complicated real-world applications.   \n\nReversible operations have recently been successfully applied to classification problems to reduce memory requirements during neural network training. This feature is accomplished by removing the need to store the input activation for computing the gradients at the backward pass and instead reconstruct them on demand. However, current approaches rely on custom implementations of backpropagation, which limits applicability and extendibility. We present MemCNN, a novel PyTorch framework that simplifies the application of reversible functions by removing the need for a customized backpropagation. 
The framework contains a set of practical generalized tools, which can wrap common operations like convolutions and batch normalization and which take care of memory management. We validate the presented framework by reproducing state-of-the-art experiments using MemCNN and by comparing classification accuracy and training time on Cifar-10 and Cifar-100. Our MemCNN implementations achieved similar classification accuracy and faster training times while retaining compatibility with the default backpropagation facilities of PyTorch.\n\n# Background\n\nReversible functions, which allow exact retrieval of their input from their output, can reduce memory overhead when used within the context of training neural networks using backpropagation. That is, since only the output needs to be stored, intermediate feature maps can be freed on the forward pass and recomputed from the output on the backward pass when required. Recently, reversible functions have been used with some success to extend the well established residual network (ResNet) for image classification from @He2015 to more memory efficient invertible convolutional neural networks [@Gomez17; @Chang17; @jaco18] showing competitive performance on datasets like Cifar-10, Cifar-100 [@krizhevsky2009learning] and ImageNet [@imagenet_cvpr09]. However, practical applicability and extendibility of reversible functions for the reduction of memory overhead have been limited, since current implementations require customized backpropagation, which does not work conveniently with modern deep learning frameworks and requires substantial manual design. \n\nThe reversible residual network (RevNet) of @Gomez17 is a variant on ResNet, which hooks into its sequential structure of residual blocks and replaces them with reversible blocks, that creates an explicit inverse for the residual blocks based on the equations from @Dinh14 on nonlinear independent components estimation. 
The reversible block takes arbitrary nonlinear functions $\\mathcal{F}$ and $\\mathcal{G}$ and renders them invertible. Their experiments show that RevNet scores similar classification performance on Cifar-10, Cifar-100, and ImageNet, with less memory overhead. \n\nReversible architectures like RevNet have subsequently been studied in the framework of ordinary differential equations (ODE) [@Chang17]. Three reversible neural networks based on Hamiltonian systems are proposed, which are similar to the RevNet, but have a specific choice for the nonlinear functions $\\mathcal{F}$ and $\\mathcal{G}$ which are shown stable during training within the ODE framework on Cifar-10 and Cifar-100.\n\nThe i-RevNet architecture extends the RevNet architecture by also making the downscale operations invertible [@jaco18], effectively creating a fully invertible architecture up until the last layer, while still showing good classification accuracy compared to ResNet on ImageNet. One particularly interesting finding shows that bottlenecks are not a necessary condition for training neural networks, which shows that the study of invertible networks can lead to a better understanding of neural network training in general.\n\nThe different reversible architectures proposed in the literature [@Gomez17; @Chang17; @jaco18] have all been modifications of the ResNet architecture and all have been implemented in TensorFlow [@TF2015]. However, these implementations rely on custom backpropagation, which limits creating novel invertible networks and application of the concepts beyond the application architecture. Our proposed framework MemCNN overcomes this issue by being compatible with the default backpropagation facilities of PyTorch. 
Furthermore, PyTorch offers convenient features over other deep learning frameworks like a dynamic computation graph and simple inspection of gradients during backpropagation, which facilitates inspection of invertible operations in neural networks.\n\n# Methods\n\n## The reversible block\nThe core operator of MemCNN is the reversible block which is an operator which takes a function $f$ and outputs a function $R : X \\to Y$, and an inverse function $R^{-1} : Y \\rightarrow{X}$ which resembles an invertible version of $f$. Here, $x\\in X$ and $y\\in Y$ can be arbitrary tensors with the same size and number of dimensions, i.e.: $\\operatorname{shape}(x)=\\operatorname{shape}(y)$. Additionally, it must be possible to partition the input $x=(x_1, x_2)$ and output tensors $y=(y_1, y_2)$ in half, where each partition has the same shape, i.e.: $\\operatorname{shape}(x_1) = \\operatorname{shape}(x_2) = \\operatorname{shape}(y_1) = \\operatorname{shape}(y_2)$. Formally, the reversible block operation (1), its inverse (2), and its partition constraints (3) provide a sufficiently general framework for implementing reversible operations. \n\nFor example, if one wants to create a reversible block performing a convolution followed by a ReLu $f$, the input $x \\in X$ is partitioned in $(x_1, x_2)$ of equal sizes to which this convolution block $f$ is applied twice (say $\\mathcal{F}$ and $\\mathcal{G}$). The Reversible Block takes these two operators ($\\mathcal{F}$ and $\\mathcal{G}$) and outputs a \"resblock\"-like version $R$ of the operator and an explicit inverse $R^{-1}$. 
Effectively the learnable function $f$ is replaced by a learnable approximation $R$ with an explicit inverse $R^{-1}$.\n\n\\begin{equation} \\quad R(x) = y \\end{equation}\n\\begin{equation} R^{-1}(y)  = x  \\end{equation}\nwith\n\\begin{equation} \\operatorname{shape}(x_1) = \\operatorname{shape}(x_2) = \\operatorname{shape}(y_1) = \\operatorname{shape}(y_2) \\end{equation}\n\n## Couplings\n\nUsing the above definitions we provide two different implementations for the reversible block in MemCNN, which we will call `couplings'. A coupling provides a reversible mapping from $(x_1, x_2)$ to $(y_1, y_2)$. MemCNN supports two couplings: the additive coupling and the affine coupling.\n\n### Additive coupling\n\nEquation 4 represents the additive coupling, which follows the equations of @Dinh14 and @Gomez17. These support a reversible implementation through arbitrary (nonlinear) functions $\\mathcal{F}$ and $\\mathcal{G}$. These functions can be convolutions, ReLus, etc., as long as they have matching input and output shapes. The additive coupling is obtained by first computing $y_1$ from input partitions $x_1, x_2$ and function $\\mathcal{F}$ and subsequently $y_2$ is computed from partitions $y_1, x_2$ and function $\\mathcal{G}$. Next, (4) can be rewritten to obtain an exact inverse function as shown in (5). Figure 1 shows a graphical representation of the additive coupling and its inverse.\n\n![Graphical representation of additive coupling. The left graph shows the forward computations and the right graph shows its inverse. First, input $x_1$ and $\\mathcal{F}(x_2)$ are added to form $y_1$, next $x_2$ and $\\mathcal{G}(y_1)$ are added to form $y_2$. Going backwards, first, $\\mathcal{G}(y_1)$ is subtracted from $y_2$ to obtain $x_2$; subsequently, $\\mathcal{F}(x_2)$ is subtracted from $y_1$ to obtain $x_1$. 
Here, $+$ and $-$ stand for respectively element-wise summation and element-wise subtraction.](additive_005.pdf)\n\n\\begin{equation}\\label{eq:additiveforward}\n\\begin{split}\ny_1 &= x_1 + \\mathcal{F}(x_2), \\\\\ny_2 &= x_2 + \\mathcal{G}(y_1) \\\\\n\\end{split}\n\\end{equation}\n\\begin{equation}\\label{eq:additivebackward}\n\\begin{split}\nx_2 &= y_2 - \\mathcal{G}(y_1), \\\\\nx_1 &= y_1 - \\mathcal{F}(x_2)\n\\end{split}\n\\end{equation}\n\n### Affine coupling\n\nEquation (6) gives the affine coupling, introduced by @dinh2016density and later used by @kingma2018glow, which is more expressive than the additive coupling. The affine coupling, similar to the additive coupling, supports a reversible implementation through arbitrary (nonlinear) functions $\\mathcal{F}$ and $\\mathcal{G}$. It also first computes $y_1$ from input partitions $x_1, x_2$ and function $\\mathcal{F}$ and subsequently it computes $y_2$ from partitions $y_1, x_2$ and function $\\mathcal{G}$. The difference with the additive coupling is that now the functions $\\mathcal{F}=(s,t)$ and $\\mathcal{G}=(s',t')$ each produce two equally sized partitions for scaling and translation, so $\\operatorname{shape}(x_1) = \\operatorname{shape}(s) = \\operatorname{shape}(t) = \\operatorname{shape}(s') = \\operatorname{shape}(t')$ holds. These components are then used to compute the output using element-wise product ($\\odot$) and element-wise exponentiation with base $e$ and element-wise addition ($+$). Equation (6) can be rewritten to obtain an exact inverse function as shown in (7), which uses element-wise division ($/$) and element-wise subtraction ($-$). Figure 2 shows a graphical representation of the affine coupling and its inverse. \n\n![Graphical representation of the affine coupling. The left graph shows the forward computations and the right graph shows its inverse. 
Here, $\\odot, /, +, -,$ and $e$ stand for element-wise multiplication, element-wise division, element-wise addition, element-wise subtraction, and element-wise exponentiation with base $e$ respectively. First, $s, t$ are computed for $\\mathcal{F}(x_2)$, next input $x_1$ is element-wise multiplied with $e^{s}$ and added to $t$ to form $y_1$, subsequently $s', t'$ are computed for $\\mathcal{G}(y_1)$ and then $x_2$ is element-wise multiplied with $e^{s'}$ and added to $t'$ to form $y_2$.](affine_005.pdf)\n\n\\begin{equation}\\label{eq:affineforward}\n\\begin{split}\ny_1 &= x_1 \\odot e^{s} + t \\;\\;\\;\\;\\, \\text{with} \\;\\; \\mathcal{F}(x_2) = (s, t)  \\\\\ny_2 &= x_2 \\odot e^{s'} + t' \\;\\;\\, \\text{with} \\;\\;\\; \\mathcal{G}(y_1) = (s', t')\n\\end{split}\n\\end{equation}\n\\begin{equation}\\label{eq:affinebackward}\n\\begin{split}\nx_2 &= (y_2 - t') / e^{s'} \\;\\;\\, \\text{with} \\;\\;\\; \\mathcal{G}(y_1) = (s', t') \\\\\nx_1 &= (y_1 - t) / e^{s} \\;\\;\\;\\:\\: \\text{with} \\;\\; \\mathcal{F}(x_2) = (s, t) \n\\end{split}\n\\end{equation}\n\n## Implementation details\n\nThe reversible block has been implemented as a \\texttt{torch.nn.Module} which wraps other PyTorch modules of arbitrary complexity for coupling functions $\\mathcal{F}$ and $\\mathcal{G}$. Each memory saving coupling is implemented using at least one \\texttt{torch.autograd.Function}, which provides a custom forward and backward pass that works with the automatic differentiation system of PyTorch. 
Memory savings are implemented at the level of the reversible block and are achieved by setting the size of the underlying tensor storage to zero for inputs on the forward pass and restoring the storage size to the original size on the backward pass once it is required for computing gradients.\n\n## Building larger networks\n\nThe reversible block $R$ can be chained by subsequent reversible blocks, e.g.: $R_3 \\circ R_2 \\circ R_1$ for reversible blocks $R_1, R_2, R_3$, which creates a fully reversible chain of operations (see Figure 3). Additionally, reversible blocks can be mixed with regular functions $f$, e.g. $f \\circ R$ or $R \\circ f$ for reversible block $R$ and regular function $f$. Note that mixing regular functions with reversible blocks often breaks the invertibility of reversible chains. \n\n![Graphical representation of chaining multiple reversible block layers. ](coupling_001.pdf)\n\n## Memory savings\n\n**Table 1:** Comparison of memory and computational complexity for training a residual network (ResNet) between various memory saving techniques (extended table from @Gomez17). $L$ depicts the number of residual layers in the ResNet.\n\n\\begin{center}\n\n\\vspace{0.3cm}\n\n\\begin{tabular}{llp{2.5cm}p{2.5cm}}\n\\hline \\textbf{Technique} & \\textbf{Authors} & \\textbf{Memory Complexity} & \\textbf{Computational Complexity} \\\\ \\hline\nNaive & & $O(L)$ & $O(L)$ \\\\\nCheckpointing & Martens et al. (2012) & $O(\\sqrt{L})$ & $O(L)$ \\\\\nRecursive & Chen et al. (2016) & $O(\\log L)$ & $O(L \\log L)$ \\\\\nAdditive coupling & Gomez et al. (2017) & $O(1)$ & $O(L)$ \\\\\nAffine coupling & Dinh et al. (2016) & $O(1)$ & $O(L)$ \\\\ \\hline\n\\end{tabular}\n\n\\vspace{0.6cm}\n\n\\end{center}\n\nThe reversible block model has an advantageous memory footprint when chained in a sequence when training neural networks. 
After computing each $R(x) = y$ by (1) on the forward pass, input $x$ can be freed from memory and be recomputed on the backward pass, using the inverse function $R^{-1}(y)=x$ from (2). Once the input is restored, the gradients for the weights and the inputs can be recomputed as normal using the PyTorch `autograd' solver. This effectively yields a memory complexity of $O(1)$ in the number of chained reversible blocks. Table 1 shows a comparison of memory versus computational complexity for different memory saving techniques.\n\n\\break\n\n# Experiments and results\n\n **Table 2a:** Accuracy comparison of the PyTorch implementation (MemCNN) versus the Tensorflow implementation from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning]. Accuracies were approximately similar between implementations.\n \n\\begin{center}\n \n\\vspace{0.3cm}\n \n\\begin{tabular}{lcccc}\n\\hline         & \\multicolumn{2}{c}{\\textbf{Cifar-10}} & \\multicolumn{2}{c}{\\textbf{Cifar-100}} \\\\\n\\textbf{Model} & \\textbf{Tensorflow} & \\textbf{PyTorch} & \\textbf{Tensorflow} & \\textbf{PyTorch} \\\\ \\hline\nResNet-32  & 92.74  & 92.86  & 69.10  & 69.81 \\\\\nResNet-110 & 93.99  & 93.55  & 73.30  & 72.40 \\\\\nResNet-164 & 94.57  & 94.80  & 76.79  & 76.47 \\\\\nRevNet-38  & 93.14  & 92.80  & 71.17  & 69.90 \\\\\nRevNet-110 & 94.02  & 94.10  & 74.00  & 73.30 \\\\\nRevNet-164 & 94.56  & 94.90  & 76.39  & 76.90 \\\\ \\hline\n\\end{tabular}\n\n\\vspace{0.3cm}\n\n\\end{center}\n\n**Table 2b:** Training time (in hours:minutes) comparison of the PyTorch implementation (MemCNN) versus the Tensorflow implementation from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning]. 
Training times were significantly less for the PyTorch implementation than for the Tensorflow implementation.\n\n\\begin{center}\n\n\\vspace{0.3cm}\n\n\\begin{tabular}{lcccc}\n\\hline         & \\multicolumn{2}{c}{\\textbf{Cifar-10}} & \\multicolumn{2}{c}{\\textbf{Cifar-100}} \\\\\n\\textbf{Model} & \\textbf{Tensorflow} & \\textbf{PyTorch} & \\textbf{Tensorflow} & \\textbf{PyTorch} \\\\ \\hline\nResNet-32  & \\;\\,\\,2:04 & 1:51  & \\;\\,\\,1:58  & 1:51 \\\\\nResNet-110 & \\;\\,\\,4:11 & 2:51  & \\;\\,\\,6:44  & 2:39 \\\\\nResNet-164 & 11:05    & 4:59  & 10:59     & 3:45 \\\\\nRevNet-38  & \\;\\,\\,2:17 & 2:09  & \\;\\,\\,2:20  & 2:16 \\\\\nRevNet-110 & \\;\\,\\,6:59 & 3:42  & \\;\\,\\,7:03  & 3:50 \\\\\nRevNet-164 & 13:09    & 7:21  & 13:12     & 7:17 \\\\ \\hline\n\\end{tabular}\n\n\\vspace{0.6cm}\n\n\\end{center}\n\nTo validate MemCNN, we reproduced the experiments from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning] using their Tensorflow [@TF2015] implementation on GitHub\\footnote{\\url{https://github.com/renmengye/revnet-public}}, and made a direct comparison with our PyTorch implementation on accuracy and train time. We have tried to keep all the experimental settings, like data loading, loss function, train procedure, and training parameters, as similar as possible. All experiments were performed on a single NVIDIA GeForce GTX 1080 with 8GB of RAM. The accuracies and training time results are listed in respectively Table 2a and Table 2b. Model performance of our PyTorch implementation obtained similar accuracy to the TensorFlow implementation with less training time on Cifar-10 and Cifar-100. All models and experiments are included in MemCNN and can be rerun for reproducibility.\n\nTable 3 shows memory usage statistics (parameters and activations) during training for all PyTorch models. Here, the ResNet model uses a conventional implementation and the RevNet model uses the reversible blocks from MemCNN. 
The results show that significant activation memory reduction was obtained using the reversible block implementation (RevNet) when the number of layers of the models increased. \n\n\\break\n\n **Table 3:** Model statistics for all PyTorch model implementations on memory usage (parameters and activations) in MB during training and the number of layers and parameters. The ResNet model was implemented using a conventional non-reversible implementation while the RevNet model uses MemCNN with memory saving reversible blocks. To facilitate comparison, each row lists the statistics of one ResNet and one RevNet model which have a comparable number of layers and number of parameters. Significant memory savings for the activations were observed when using reversible operations (RevNet) as the number of layers increased. Model parameter memory usage stayed roughly the same between implementations. \n\n\\begin{center}\n\n\\vspace{0.3cm}\n\n\\begin{tabular}{rr rr rr rr}\n\\hline \\multicolumn{2}{c}{\\textbf{Layers}} & \\multicolumn{2}{c}{\\textbf{Parameters}} & \\multicolumn{2}{c}{\\textbf{Parameters (MB)}} & \\multicolumn{2}{c}{\\textbf{Activations (MB)}} \\\\\n\\multicolumn{2}{c}{\\textbf{ResNet  RevNet}} & \\multicolumn{2}{c}{\\textbf{ResNet  RevNet}} & \\multicolumn{2}{c}{\\textbf{ResNet  RevNet}} & \\multicolumn{2}{c}{\\textbf{ResNet  RevNet}} \\\\ \\hline\n\\quad  32 &  38 \\quad \\quad &  466906 &  573994 & \\quad \\enspace 1.9 & 2.3 \\quad \\enspace & \\quad 238.6  & 85.6  \\enspace \\enspace \\\\\n\\quad 110 & 110 \\quad \\quad & 1730714 & 1854890 & \\quad \\enspace 6.8 & 7.3 \\quad \\enspace & \\quad 810.7  & 85.7  \\enspace \\enspace \\\\\n\\quad 164 & 164 \\quad \\quad & 1704154 & 1983786 & \\quad \\enspace 6.8 & 7.9 \\quad \\enspace & \\quad 2452.8 & 432.7 \\enspace \\enspace \\\\ \\hline\n\\end{tabular}\n\n\\vspace{0.6cm}\n\n\\end{center}\n\n# Works using MemCNN\n\nMemCNN has recently been used to create reversible GANs for memory-efficient image-to-image 
translation by @Ouderaa_2019_CVPR. Image-to-image translation considers the problem of mapping both $X \\rightarrow Y$ and $Y \\rightarrow X$ given two image domains $X$ and $Y$ using either paired or unpaired examples. In this work, the CycleGAN [@zhu2017unpaired] model has been enlarged and extended with an invertible core using the reversible block, which they call RevGAN. Since the invertible core is weight tied, training the model for the mapping $X \\rightarrow Y$ automatically trains the model for mapping $Y \\rightarrow X$. They show similar or increased performance of RevGAN with respect to similar non-invertible models like the CycleGAN with less memory overhead during training. The RevGAN model has also been applied to chest CT images [@ouderaa:MIDLAbstract2019a].\n\n# Conclusion\n\nWe have presented MemCNN, a novel PyTorch framework, for creating and applying reversible operations for neural networks. It shows similar accuracy on Cifar-10 and Cifar-100 datasets with the current state-of-the-art method for reversible operations in Tensorflow and provides overall faster training times. The main features of the framework are smooth integration of reversible functions with other non-reversible functions by removing the need for a custom backpropagation and simple wrapping of arbitrary complex non-invertible nonlinear functions. The presented framework is intended to facilitate the study and application of invertible functions in the context of neural networks.\n\n# Acknowledgements\n\nThis work was supported by research grants from the Netherlands Organization for Scientific Research (NWO), the Netherlands and Canon Medical Systems Corporation, Japan.\n\n# References\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\nSimpleITK\ntorch>=1.0.0\ntorchvision\ntqdm\npathlib2\n"
  },
  {
    "path": "setup.cfg",
    "content": "[bumpversion]\ncurrent_version = 1.5.2\ncommit = True\ntag = True\ntag_name = {new_version}\n\n[bumpversion:file:setup.py]\nsearch = VERSION = '{current_version}'\nreplace = VERSION = '{new_version}'\n\n[bumpversion:file:memcnn/__init__.py]\nsearch = __version__ = '{current_version}'\nreplace = __version__ = '{new_version}'\n\n[bdist_wheel]\nuniversal = 1\n\n[flake8]\nexclude = docs\n\n[aliases]\ntest = pytest\n\n[tool:pytest]\ncollect_ignore = ['setup.py']\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport sys\nfrom distutils.core import setup\nfrom setuptools.command.install import install\nfrom setuptools import find_packages\n\n# circleci.py version\nVERSION = '1.5.2'\n\nwith open('README.rst', 'r') as fh:\n    long_description = fh.read().split('Results\\n-------')[0]\n\nwith open('requirements.txt', 'r') as fh:\n    requirements = [e.strip() for e in fh.readlines() if e.strip() != '']\n\n\nclass VerifyVersionCommand(install):\n    \"\"\"Custom command to verify that the git tag matches our version\"\"\"\n    description = 'verify that the git tag matches our version'\n\n    def run(self):\n        tag = os.getenv('CIRCLE_TAG')\n\n        if tag != VERSION:\n            info = \"Git tag: {0} does not match the version of this app: {1}\".format(\n                tag, VERSION\n            )\n            sys.exit(info)\n\n\nsetup(\n    name='memcnn',\n    version=VERSION,\n    author='S.C. van de Leemput',\n    author_email='silvandeleemput@gmail.com',\n    packages=find_packages(),\n    include_package_data=True,\n    scripts=[],\n    url='http://pypi.python.org/pypi/memcnn/',\n    license='LICENSE.txt',\n    description='A PyTorch framework for developing memory efficient deep invertible networks.',\n    long_description=long_description,\n    long_description_content_type='text/x-rst',\n    install_requires=requirements,\n    classifiers=[\n        \"License :: OSI Approved :: MIT License\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.6\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Development Status :: 3 - Alpha\",\n        \"Intended Audience :: Science/Research\",\n        \"Topic :: Scientific/Engineering\",\n        \"Topic :: Scientific/Engineering :: Medical Science Apps.\",\n        \"Topic :: Scientific/Engineering :: Information Analysis\",\n        \"Topic :: Software Development :: Libraries\",\n        \"Operating System :: OS Independent\"\n    
    ],\n    keywords='memcnn invertible PyTorch',\n    cmdclass={\n        'verify': VerifyVersionCommand,\n    }\n)\n"
  },
  {
    "path": "tox.ini",
    "content": "[tox]\nenvlist={py38}-torch{10,11,14,17,latest},release,docs\nskipsdist=True\n\n[testenv]\npassenv=LC_ALL, LANG, HOME\ncommands=pytest --cov=memcnn --cov-report=html --cov-report=xml --junitxml=test-reports/junit.xml\ndeps=\n    pip==19.1.1\n    numpy\n    SimpleITK\n    tqdm\n    pytest\n    pytest-cov\n    torch14: torch==1.4.0\n    torch14: torchvision==0.5.0\n    torch17: torch==1.7.0\n    torch17: torchvision==0.8.1\n    torchlatest: torch\n    torchlatest: torchvision\n\n[testenv:release]\ndeps=\n  bumpversion\ncommands=bumpversion --dry-run minor\n\n# generate the sphinx doc\n[testenv:docs]\nbasepython=python\nchangedir=docs\ndeps=-rdocsRequirements.txt\ncommands=sphinx-build -b linkcheck -b html -d {envtmpdir}/doctrees . {envtmpdir}/html\n"
  }
]