Repository: silvandeleemput/memcnn Branch: master Commit: 439ac2673093 Files: 88 Total size: 235.1 KB Directory structure: gitextract_1awyh7rj/ ├── .circleci/ │ └── config.yml ├── .coveragerc ├── .github/ │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE.md │ └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .readthedocs.yml ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE.txt ├── MANIFEST.in ├── README.rst ├── bandit.yml ├── devRequirements.txt ├── docker/ │ └── Dockerfile ├── docs/ │ ├── Makefile │ ├── authors.rst │ ├── conf.py │ ├── contributing.rst │ ├── history.rst │ ├── index.rst │ ├── installation.rst │ ├── make.bat │ ├── modules.rst │ ├── readme.rst │ ├── usage.rst │ └── usage_experiments.rst ├── docsRequirements.txt ├── memcnn/ │ ├── .editorconfig │ ├── __init__.py │ ├── config/ │ │ ├── __init__.py │ │ ├── config.json.example │ │ ├── experiments.json │ │ └── tests/ │ │ ├── __init__.py │ │ └── test_config.py │ ├── data/ │ │ ├── __init__.py │ │ ├── cifar.py │ │ ├── sampling.py │ │ └── tests/ │ │ ├── __init__.py │ │ ├── test_cifar.py │ │ └── test_sampling.py │ ├── examples/ │ │ ├── minimal.py │ │ └── test_examples.py │ ├── experiment/ │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── manager.py │ │ └── tests/ │ │ ├── __init__.py │ │ ├── test_factory.py │ │ └── test_manager.py │ ├── models/ │ │ ├── __init__.py │ │ ├── additive.py │ │ ├── affine.py │ │ ├── resnet.py │ │ ├── revop.py │ │ └── tests/ │ │ ├── __init__.py │ │ ├── test_amp.py │ │ ├── test_couplings.py │ │ ├── test_is_invertible_module.py │ │ ├── test_memory_saving.py │ │ ├── test_models.py │ │ ├── test_multi.py │ │ ├── test_resnet.py │ │ ├── test_revop.py │ │ └── test_split_dim.py │ ├── train.py │ ├── trainers/ │ │ ├── __init__.py │ │ ├── classification.py │ │ └── tests/ │ │ ├── __init__.py │ │ ├── resources/ │ │ │ └── experiments.json │ │ ├── test_classification.py │ │ └── test_train.py │ └── utils/ │ ├── __init__.py │ ├── log.py │ ├── loss.py │ ├── stats.py │ └── tests/ 
│ ├── __init__.py │ ├── test_log.py │ ├── test_loss.py │ └── test_stats.py ├── paper/ │ ├── README │ ├── paper.bib │ └── paper.md ├── requirements.txt ├── setup.cfg ├── setup.py └── tox.ini ================================================ FILE CONTENTS ================================================ ================================================ FILE: .circleci/config.yml ================================================ version: 2.1 aliases: - &container_python docker: - image: cimg/python:3.8.4 # primary container for the build job - &run_task_install_tox_dependencies run: name: install tox dependencies command: | sudo apt-get update sudo apt install -y build-essential libssl-dev libpython-dev python python-pip sudo -H pip install --upgrade pip tox virtualenv orbs: codecov: codecov/codecov@1.0.4 jobs: testing: parameters: tests: type: string default: py38-torch10,py38-torch11,py38-torch14,py38-torch17 <<: *container_python steps: - checkout - *run_task_install_tox_dependencies - run: name: execute pytests << parameters.tests >> no_output_timeout: 30m command: | mkdir test-reports tox -e << parameters.tests >> - codecov/upload: flags: backend,unittest - store_artifacts: path: htmlcov - store_test_results: path: test-reports - codecov/upload: file: coverage/*.json flags: frontend builddocs: <<: *container_python steps: - checkout - *run_task_install_tox_dependencies - run: name: build the sphinx documentation command: | tox -e docs conda_deploy: parameters: versions: type: string default: "3.7 3.8" docker: - image: continuumio/miniconda3 steps: - checkout - run: name: install conda dependencies command: | conda install conda-build anaconda-client conda-verify -y - run: name: generate skeleton file from PyPI and complete recipe command: | cd ~ conda skeleton pypi memcnn cd memcnn python -c "f = open('meta.yaml', 'r'); data = f.read(); f.close(); data=data.replace(' torch ', ' pytorch ').replace('your-github-id-here', 'silvandeleemput').replace('pillow\n - python', 
'pillow\n - pip\n - python').replace(' pip', ' pip >=18.0'); f = open('meta.yaml', 'w'); f.write(data); f.close();" cat ~/memcnn/meta.yaml - run: name: build binary artifacts for python versions << parameters.versions >> no_output_timeout: 30m command: | cd ~/memcnn PYTHON_VERSIONS=( << parameters.versions >> ) for i in "${PYTHON_VERSIONS[@]}" do echo $i conda-build -c conda-forge -c simpleitk -c pytorch --numpy 1.15.1 --python $i . done - run: name: upload binary artifacts for all platforms to anaconda cloud command: | anaconda login --user=silvandeleemput --password=$CONDA_PASSWORD find /opt/conda/conda-bld/ -name *.tar.bz2 | while read file do echo $file anaconda upload $file --skip-existing --all done deploy: docker: - image: cimg/python:3.8.4 steps: - checkout - restore_cache: key: v1-dependency-cache-{{ checksum "setup.py" }} - run: name: install python dependencies command: | python3 -m venv venv . venv/bin/activate pip install --upgrade pip pip install pylint doc8 coverage codecov twine pip install -e . - save_cache: key: v1-dependency-cache-{{ checksum "setup.py" }} paths: - "venv" - run: name: verify git tag vs. version command: | python3 -m venv venv . venv/bin/activate python setup.py verify - run: name: init .pypirc command: | echo -e "[pypi]" >> ~/.pypirc echo -e "username = Sil" >> ~/.pypirc echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc - run: name: createpackages command: | python setup.py sdist python setup.py bdist_wheel - run: name: upload to pypi command: | . 
venv/bin/activate twine upload dist/* - run: name: trigger docker hub master branch build command: | curl -H "Content-Type: application/json" --data '{"source_type": "Branch", "source_name": "master"}' -X POST $DOCKER_TRIGGER_URL - run: name: trigger docker hub latest tag build command: | curl -H "Content-Type: application/json" --data '{"source_type": "Tag", "source_name": "'"$CIRCLE_TAG"'"}' -X POST $DOCKER_TRIGGER_URL workflows: version: 2 build_test_and_deploy: jobs: - testing: name: testing_py38_torch14 tests: py38-torch14 filters: tags: only: /.*/ - testing: name: testing_py38_torch17 tests: py38-torch17 filters: tags: only: /.*/ - deploy: requires: - testing_py38_torch14 - testing_py38_torch17 filters: tags: only: /[0-9]+(\.[0-9]+)*/ branches: ignore: /.*/ - conda_deploy: name: conda_deploy_py37 requires: - deploy versions: "3.7" filters: tags: only: /[0-9]+(\.[0-9]+)*/ branches: ignore: /.*/ - conda_deploy: name: conda_deploy_py38 requires: - deploy versions: "3.8" filters: tags: only: /[0-9]+(\.[0-9]+)*/ branches: ignore: /.*/ ================================================ FILE: .coveragerc ================================================ [run] omit = env/* venv/* tests/* setup.py */tests/*.py source = . ================================================ FILE: .github/CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at silvandeleemput@gmail.com. 
All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq [homepage]: https://www.contributor-covenant.org ================================================ FILE: .github/CONTRIBUTING.md ================================================ # Contributing Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. The latest information about how to contribute to MemCNN can be found here: https://memcnn.readthedocs.io/en/latest/contributing.html ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ * MemCNN version: * PyTorch version: * Python version: * Operating System: ### Description Describe what you were trying to get done. Tell us what happened, what went wrong, and what you expected to happen. ### What I Did ``` Paste the command(s) you ran and the output. If there was a crash, please include the traceback here. ``` ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ### What was the problem? {Please write here} ### How this PR fixes the problem?
{Please write here} ### Check lists (check `x` in `[ ]` of list items) - [ ] Test passed - [ ] Coding style (indentation, etc) ### Additional Comments (if any) {Please write here} ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv venv/ ENV/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # remove memcnn config /memcnn/config/config.json # PyCharm configs /.idea ================================================ FILE: .readthedocs.yml ================================================ # .readthedocs.yml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Build documentation with MkDocs #mkdocs: # configuration: mkdocs.yml # Optionally build your docs in additional 
formats such as PDF and ePub formats: - htmlzip - epub # Optionally set the version of Python and requirements required to build your docs python: version: 3.7 install: - requirements: docsRequirements.txt ================================================ FILE: AUTHORS.rst ================================================ ======= Credits ======= Development Lead ---------------- * Sil van de Leemput Contributors ------------ * Tycho van der Ouderaa * Jonas Teuwen * Bram van Ginneken * Rashindra Manniesing ================================================ FILE: CONTRIBUTING.rst ================================================ .. highlight:: shell ============ Contributing ============ Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. You can contribute in many ways: Types of Contributions ---------------------- Report Bugs ~~~~~~~~~~~ Report bugs at https://github.com/silvandeleemput/memcnn/issues. If you are reporting a bug, please include: * Your operating system name and version. * Any details about your local setup that might be helpful in troubleshooting. * Detailed steps to reproduce the bug. Fix Bugs ~~~~~~~~ Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants to implement it. Implement Features ~~~~~~~~~~~~~~~~~~ Look through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it. Write Documentation ~~~~~~~~~~~~~~~~~~~ MemCNN could always use more documentation, whether as part of the official MemCNN docs, in docstrings, or even on the web in blog posts, articles, and such. Submit Feedback ~~~~~~~~~~~~~~~ The best way to send feedback is to file an issue at https://github.com/silvandeleemput/memcnn/issues. If you are proposing a feature: * Explain in detail how it would work. * Keep the scope as narrow as possible, to make it easier to implement. 
* Remember that this is a volunteer-driven project, and that contributions are welcome :) Get Started! ------------ Ready to contribute? Here's how to set up `memcnn` for local development. 1. Fork the `memcnn` repo on GitHub. 2. Clone your fork locally:: $ git clone git@github.com:your_name_here/memcnn.git 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: $ mkvirtualenv memcnn $ cd memcnn/ $ python setup.py develop 4. Create a branch for local development:: $ git checkout -b name-of-your-bugfix-or-feature Now you can make your changes locally. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 memcnn tests $ python setup.py test or py.test $ tox To get flake8 and tox, just pip install them into your virtualenv. 6. Commit your changes and push your branch to GitHub:: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature 7. Submit a pull request through the GitHub website. Pull Request Guidelines ----------------------- Before you submit a pull request, check that it meets these guidelines: 1. The pull request should include tests. 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. 3. The pull request should work for Python 2.7, 3.5+, and for PyPy. Check through tox that all the tests pass for all supported Python versions. Tips ---- To run a subset of tests:: $ pytest memcnn/memcnn/models/tests/test_revop.py To run a specific test:: $ pytest memcnn/memcnn/models/tests/test_revop.py::test_reversible_block_fwd_bwd Deploying --------- A reminder for the maintainers on how to deploy. Make sure all your changes are committed (including an entry in HISTORY.rst). 
Then run:: $ bumpversion patch # possible: major / minor / patch $ git push $ git push origin CircleCI will then deploy to PyPI if tests pass. ================================================ FILE: HISTORY.rst ================================================ ======= History ======= 1.5.2 (2023-05-10) ------------------ * Fixed issue with CIFAR data loaders not being able to be pickled because of local Lambda operations * Fixed CI issues, disabled PyTorch v1.0, v1.1, and latest checks 1.5.1 (2021-08-07) ------------------ * Added support for 2-dimensional inputs for AffineAdapterSigmoid * Fixed CI issues 1.5.0 (2020-11-24) ------------------ * Added support for mixed-precision training using torch.cuda.amp (inputs fixed to float32 for now) * Added support for PyTorch v1.7 * Dropped support for PyTorch < v1.0 and Python 2 * Removed the version limit for Pillow in the requirements 1.4.0 (2020-06-05) ------------------ * Added support for splitting on arbitrary dimensions to the Couplings. Big thanks to ClashLuke for the PR * Added a preserve_rng_state option to the InvertibleModuleWrapper 1.3.2 (2020-03-05) ------------------ * Improved InvertibleModuleWrapper * Added support for multi input/output invertible operations! Big thanks to Christian Etmann for the PR * Improved the is_invertible_module test * Added multi input/output checks * Fixed random seed per default * Additional warning checks have been added 1.3.1 (2020-03-02) ------------------ * HOTFIX InvertibleCheckpointFunction uses ref_count for inputs as well to avoid memory spikes 1.3.0 (2020-03-01) ------------------ * Updated underlying mechanics for the InvertibleModuleWrapper * Hooks have been replaced by a torch.autograd.Function called InvertibleCheckpointFunction * Identity functions are now supported * Reported unstable memory behavior should be fixed now when using the InvertibleModuleWrapper! 
* Minor changes to test suite 1.2.1 (2020-02-24) ------------------ * Added InvertibleModuleWrapper support to is_invertible_module test 1.2.0 (2020-01-19) ------------------ * Replaced TensorBoard logging with simple json file logging which removed the cumbersome TensorBoard and TensorFlow dependencies * Updated the Dockerfile for Python37 and PyTorch 1.4.0 * Updated the CI tests Py36 versions to Py37, also added a new CI test for PyTorch 1.4.0 1.1.1 (2020-01-11) ------------------ * Fixed some versions in the requirements for TensorFlow and Pillow to avoid errors and segfaults * The module auto documentation has been updated for the new API changes 1.1.0 (2019-12-15) ------------------ * A complete refactor of MemCNN with changes to the API * Factored out the code responsible for the memory savings in a separate InvertibleModuleWrapper and reimplemented it using hooks * The InvertibleModuleWrapper allows for arbitrary invertible functions now (not just the additive and affine couplings) * The AdditiveBlock and AffineBlock have been refactored to AdditiveCoupling and AffineCoupling * The ReveribleBlock is now deprecated * The documentation and examples have been updated for the new API changes 1.0.1 (2019-12-08) ------------------ * Bug fixes related to SummaryIterator import in Tensorflow 2 (location of summary_iterator has changed in TensorFlow) * Bug fixes related to NSamplesRandomSampler nsamples attribute (would crash if no-gpu and numpy.int were given) 1.0.0 (2019-07-28) ------------------ * Major release for completing the JOSS review: * Anaconda cloud and codacy code quality CI * Updated/improved documentation 0.3.5 (2019-07-28) ------------------ * Added CI for anaconda cloud * Documented conda installation steps * Minor test release for testing CI build 0.3.4 (2019-07-26) ------------------ * Performed changes recommended by JOSS reviewers: * Added requirements.txt to manifest.in * Added codacy code quality integration * Improved documentation * Setup 
proper github contribution templates 0.3.3 (2019-07-10) ------------------ * Added docker build triggers to CI * Finalized JOSS paper.md 0.3.2 (2019-07-10) ------------------ * Added docker build shield * Fixed a bug with device agnostic tensor generation for loss.py * Code cleanup resnet.py * Added examples to distribution with pytests * Improved documentation 0.3.1 (2019-07-09) ------------------ * Added experiments.json and config.json.example data files to the distribution * Fixed documentation issues with mock modules 0.3.0 (2019-07-09) ------------------ * Updated major bug in distribution setup.py * Removed older releases due to bug * Added the ReversibleBlock at the module level * Splitted keep_input into keep_input and keep_input_inverse 0.2.1 (2019-06-06 - Removed) ---------------------------- * Patched the memory saving tests 0.2.0 (2019-05-28 - Removed) ---------------------------- * Minor update with better coverage and affine coupling support 0.1.0 (2019-05-24 - Removed) ---------------------------- * First release on PyPI ================================================ FILE: LICENSE.txt ================================================ MIT License Copyright (c) 2018 Sil C. van de Leemput Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include AUTHORS.rst include CONTRIBUTING.rst include HISTORY.rst include LICENSE.txt include README.rst include requirements.txt include devRequirements.txt include docsRequirements.txt recursive-include tests * recursive-exclude * __pycache__ recursive-exclude * *.py[co] recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif recursive-include memcnn/config config.json.example experiments.json ================================================ FILE: README.rst ================================================ ====== MemCNN ====== .. image:: https://img.shields.io/badge/maintenance-unmaintained-red.svg :alt: Unmaintained! :target: https://github.com/silvandeleemput/memcnn .. image:: https://img.shields.io/circleci/build/github/silvandeleemput/memcnn/master.svg :alt: CircleCI - Status master branch :target: https://circleci.com/gh/silvandeleemput/memcnn/tree/master .. image:: https://readthedocs.org/projects/memcnn/badge/?version=latest :alt: Documentation - Status master branch :target: https://memcnn.readthedocs.io/en/latest/?badge=latest .. image:: https://img.shields.io/codacy/grade/95de32e0d7c54d038611da47e9f0948b/master.svg :alt: Codacy - Branch grade :target: https://app.codacy.com/project/silvandeleemput/memcnn/dashboardgit .. image:: https://img.shields.io/codecov/c/gh/silvandeleemput/memcnn/master.svg :alt: Codecov - Status master branch :target: https://codecov.io/gh/silvandeleemput/memcnn .. image:: https://img.shields.io/pypi/v/memcnn.svg :alt: PyPI - Latest release :target: https://pypi.python.org/pypi/memcnn .. 
image:: https://img.shields.io/conda/vn/silvandeleemput/memcnn?label=anaconda :alt: Conda - Latest release :target: https://anaconda.org/silvandeleemput/memcnn .. image:: https://img.shields.io/pypi/implementation/memcnn.svg :alt: PyPI - Implementation :target: https://pypi.python.org/pypi/memcnn .. image:: https://img.shields.io/pypi/pyversions/memcnn.svg :alt: PyPI - Python version :target: https://pypi.python.org/pypi/memcnn .. image:: https://img.shields.io/github/license/silvandeleemput/memcnn.svg :alt: GitHub - Repository license :target: https://github.com/silvandeleemput/memcnn/blob/master/LICENSE.txt .. image:: http://joss.theoj.org/papers/10.21105/joss.01576/status.svg :alt: JOSS - DOI :target: https://doi.org/10.21105/joss.01576 A `PyTorch `__ framework for developing memory-efficient invertible neural networks. * Free software: `MIT license `__ (please cite our work if you use it) * Documentation: https://memcnn.readthedocs.io. * Installation: https://memcnn.readthedocs.io/en/latest/installation.html ⚠️ Project Status: Unmaintained This repository is no longer actively maintained. The code is kept available for reference and historical purposes, but no new features, bug fixes, or support should be expected. If you find the project useful, feel free to fork it and continue development. Features -------- * Enable memory savings during training by wrapping arbitrary invertible PyTorch functions with the `InvertibleModuleWrapper` class. * Simple toggling of memory saving by setting the `keep_input` property of the `InvertibleModuleWrapper`. * Turn arbitrary non-linear PyTorch functions into invertible versions using the `AdditiveCoupling` or the `AffineCoupling` classes. * Training and evaluation code for reproducing RevNet experiments using MemCNN. * CI tests for Python v3.7 and torch v1.0, v1.1, v1.4 and v1.7 with good code coverage. Examples -------- Creating an AdditiveCoupling with memory savings ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. 
code:: python import torch import torch.nn as nn import memcnn # define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d class ExampleOperation(nn.Module): def __init__(self, channels): super(ExampleOperation, self).__init__() self.seq = nn.Sequential( nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), padding=1), nn.BatchNorm2d(num_features=channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.seq(x) # generate some random input data (batch_size, num_channels, y_elements, x_elements) X = torch.rand(2, 10, 8, 8) # application of the operation(s) the normal way model_normal = ExampleOperation(channels=10) model_normal.eval() Y = model_normal(X) # turn the ExampleOperation invertible using an additive coupling invertible_module = memcnn.AdditiveCoupling( Fm=ExampleOperation(channels=10 // 2), Gm=ExampleOperation(channels=10 // 2) ) # test that it is actually a valid invertible module (has a valid inverse method) assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape) # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True) # by default the module is set to training, the following sets this to evaluation # note that this is required to pass input tensors to the model with requires_grad=False (inference only) invertible_module_wrapper.eval() # test that the wrapped module is also a valid invertible module assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape) # compute the forward pass using the wrapper Y2 = invertible_module_wrapper.forward(X) # the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2 X2 = invertible_module_wrapper.inverse(Y2) # test that the input and approximation are similar assert torch.allclose(X, X2, atol=1e-06) Run 
PyTorch Experiments ----------------------- After installing MemCNN run: .. code:: bash python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda] * Available values for ``DATASET`` are ``cifar10`` and ``cifar100``. * Available values for ``MODEL`` are ``resnet32``, ``resnet110``, ``resnet164``, ``revnet38``, ``revnet110``, ``revnet164`` * Use the ``--fresh`` flag to remove earlier experiment results. * Use the ``--no-cuda`` flag to train on the CPU rather than the GPU through CUDA. Datasets are automatically downloaded if they are not available. When using Python 3.* replace the ``python`` directive with the appropriate Python 3 directive. For example when using the MemCNN docker image use ``python3.6``. When MemCNN was installed using `pip` or from sources you might need to setup a configuration file before running this command. Read the corresponding section about how to do this here: https://memcnn.readthedocs.io/en/latest/installation.html Results ------- TensorFlow results were obtained from `the reversible residual network <https://arxiv.org/abs/1707.04585>`__ running the code from their `GitHub <https://github.com/renmengye/revnet-public>`__. The PyTorch results listed were recomputed on June 11th 2018, and differ from the results in the ICLR paper. The Tensorflow results are still the same. 
Prediction accuracy ^^^^^^^^^^^^^^^^^^^ +------------+------------------------+--------------------------+----------------------+----------------------+ | | Cifar-10 | Cifar-100 | +------------+------------------------+--------------------------+----------------------+----------------------+ | Model | Tensorflow | PyTorch | Tensorflow | PyTorch | +============+========================+==========================+======================+======================+ | resnet-32 | 92.74 | 92.86 | 69.10 | 69.81 | +------------+------------------------+--------------------------+----------------------+----------------------+ | resnet-110 | 93.99 | 93.55 | 73.30 | 72.40 | +------------+------------------------+--------------------------+----------------------+----------------------+ | resnet-164 | 94.57 | 94.80 | 76.79 | 76.47 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-38 | 93.14 | 92.80 | 71.17 | 69.90 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-110 | 94.02 | 94.10 | 74.00 | 73.30 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-164 | 94.56 | 94.90 | 76.39 | 76.90 | +------------+------------------------+--------------------------+----------------------+----------------------+ Training time (hours : minutes) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +------------+------------------------+--------------------------+----------------------+----------------------+ | | Cifar-10 | Cifar-100 | +------------+------------------------+--------------------------+----------------------+----------------------+ | Model | Tensorflow | PyTorch | Tensorflow | PyTorch | +============+========================+==========================+======================+======================+ | resnet-32 | 2:04 | 1:51 | 1:58 | 1:51 | 
+------------+------------------------+--------------------------+----------------------+----------------------+ | resnet-110 | 4:11 | 2:51 | 6:44 | 2:39 | +------------+------------------------+--------------------------+----------------------+----------------------+ | resnet-164 | 11:05 | 4:59 | 10:59 | 3:45 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-38 | 2:17 | 2:09 | 2:20 | 2:16 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-110 | 6:59 | 3:42 | 7:03 | 3:50 | +------------+------------------------+--------------------------+----------------------+----------------------+ | revnet-164 | 13:09 | 7:21 | 13:12 | 7:17 | +------------+------------------------+--------------------------+----------------------+----------------------+ Memory consumption of model training in PyTorch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+ | Layers | Parameters | Parameters (MB) | Activations (MB) | +------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+ | ResNet | RevNet | ResNet | RevNet | ResNet | RevNet | ResNet | RevNet | +========================+==========================+======================+======================+========================+==========================+======================+======================+ | 32 | 38 | 466906 | 573994 | 1.9 | 2.3 | 238.6 | 85.6 | +------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+ | 110 | 110 | 1730714 | 
1854890 | 6.8 | 7.3 | 810.7 | 85.7 | +------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+ | 164 | 164 | 1704154 | 1983786 | 6.8 | 7.9 | 2452.8 | 432.7 | +------------------------+--------------------------+----------------------+----------------------+------------------------+--------------------------+----------------------+----------------------+ The `ResNet` model is the conventional Residual Network implementation in PyTorch, while the RevNet model uses the `memcnn.InvertibleModuleWrapper` to achieve memory savings. Works using MemCNN ------------------ * `MemCNN: a Framework for Developing Memory Efficient Deep Invertible Networks `__ by Sil C. van de Leemput et al. * `Reversible GANs for Memory-efficient Image-to-Image Translation `__ by Tycho van der Ouderaa et al. * `Chest CT Super-resolution and Domain-adaptation using Memory-efficient 3D Reversible GANs `__ by Tycho van der Ouderaa et al. * `iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling `__ by Christian Etmann et al. Citation -------- Sil C. van de Leemput, Jonas Teuwen, Bram van Ginneken, and Rashindra Manniesing. MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks. Journal of Open Source Software, 4, 1576, http://dx.doi.org/10.21105/joss.01576, 2019. If you use our code, please cite: .. code:: bibtex @article{vandeLeemput2019MemCNN, journal = {Journal of Open Source Software}, doi = {10.21105/joss.01576}, issn = {2475-9066}, number = {39}, publisher = {The Open Journal}, title = {MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks}, url = {http://dx.doi.org/10.21105/joss.01576}, volume = {4}, author = {Sil C. 
{van de} Leemput and Jonas Teuwen and Bram {van} Ginneken and Rashindra Manniesing},
        pages = {1576},
        date = {2019-07-30},
        year = {2019},
        month = {7},
        day = {30},
    }


================================================
FILE: bandit.yml
================================================
skips: ['B101']


================================================
FILE: devRequirements.txt
================================================
-r requirements.txt
bumpversion
wheel
watchdog
flake8
tox
coverage
Sphinx
twine
pytest
pytest-cov
pytest-runner


================================================
FILE: docker/Dockerfile
================================================
FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04

RUN apt-get update && apt-get install -y \
    software-properties-common \
    && \
    rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa && apt-get update

RUN apt-get install -y \
    git \
    python3.7-dev \
    python3-pip \
    sudo \
    && rm -rf /var/lib/apt/lists/*

# Add user with valid password
RUN useradd -ms /bin/bash user
RUN (echo user ; echo user) | passwd user

# Configure sudo
RUN usermod -a -G sudo user

# Install necessary python libraries
RUN python3.7 -m pip install pip --upgrade
# NOTE(review): removed a duplicated "pip install" from the original command, which would
# have made pip try to install the unrelated packages "pip" and "install" as well.
RUN python3.7 -m pip install torch===1.7.0 torchvision===0.8.1 torchaudio===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
RUN python3.7 -m pip install memcnn
RUN python3.7 -m pip install pytest

# Set MemCNN config file for user environment
RUN python3.7 -c "import os, shutil, memcnn; path=os.path.join(os.path.dirname(memcnn.__file__), 'config'); shutil.copy(os.path.join(path, 'config.json.example'), os.path.join(path, 'config.json'));"

# Change user and prepare user data folders
USER user
WORKDIR /home/user
RUN mkdir data
RUN mkdir experiments

ENTRYPOINT /bin/bash


================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the
command line. SPHINXOPTS = SPHINXBUILD = python -msphinx SPHINXPROJ = memcnn SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/authors.rst ================================================ .. include:: ../AUTHORS.rst ================================================ FILE: docs/conf.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # # memcnn documentation build configuration file, created by # sphinx-quickstart on Fri Jun 9 13:47:02 2017. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory is # relative to the documentation root, use os.path.abspath to make it # absolute, like shown here. # import os import sys sys.path.insert(0, os.path.abspath('..')) # this gets the memcnn_version without importing memcnn and causing troubles with mock later on with open(os.path.join(os.path.dirname(__file__), '..', 'memcnn', '__init__.py'), 'r') as f: memcnn_version = [line.split("'")[1] for line in f.readlines() if '__version__' in line][0] # -- General configuration --------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. 
# # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The master toctree document. master_doc = 'index' # Napoleon settings napoleon_google_docstring = False napoleon_numpy_docstring = True # autodoc settings autoclass_content = 'both' autodoc_mock_imports = ['torch', 'torch.nn', 'numpy', 'torchvision'] intersphinx_mapping = { 'python': ('https://docs.python.org/', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'torch': ('https://pytorch.org/docs/stable/', None) } mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML" mathjax_config = { 'extensions': ['tex2jax.js'], 'jax': ['input/TeX', 'output/HTML-CSS'], } # General information about the project. project = u'MemCNN' copyright = u"2019, Sil van de Leemput" author = u"Sil van de Leemput" # The version info for the project you're documenting, acts as replacement # for |version| and |release|, also used in various other places throughout # the built documents. # # The short X.Y version. version = memcnn_version # The full version, including alpha/beta/rc tags. release = memcnn_version # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. 
language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False # -- Options for HTML output ------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] # -- Options for HTMLHelp output --------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'memcnndoc' # -- Options for LaTeX output ------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass # [howto, manual, or own class]). latex_documents = [ (master_doc, 'memcnn.tex', u'MemCNN Documentation', u'Sil van de Leemput', 'manual'), ] # -- Options for manual page output ------------------------------------ # One entry per manual page. 
List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'memcnn', u'MemCNN Documentation', [author], 1) ] # -- Options for Texinfo output ---------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'memcnn', u'MemCNN Documentation', author, 'memcnn', 'A PyTorch framework for developing memory efficient deep invertible networks.', 'Miscellaneous'), ] ================================================ FILE: docs/contributing.rst ================================================ .. include:: ../CONTRIBUTING.rst ================================================ FILE: docs/history.rst ================================================ .. include:: ../HISTORY.rst ================================================ FILE: docs/index.rst ================================================ Welcome to MemCNN's documentation! ====================================== .. toctree:: :maxdepth: 2 :caption: Contents: readme installation usage modules contributing authors history Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/installation.rst ================================================ .. highlight:: shell ============ Installation ============ Requirements ------------ - `Python `__ 3.6+ - `PyTorch `__ 1.0+ (CUDA support recommended) Stable release -------------- These are the preferred methods to install MemCNN, as they will always install the most recent stable release. PyPi ^^^^ To install MemCNN using the Python package manager, run this command in your terminal: .. code-block:: console $ pip install memcnn If you don't have `pip`_ installed, this `Python installation guide`_ can guide you through the process. .. _pip: https://pip.pypa.io .. 
_Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ Anaconda ^^^^^^^^ To install MemCNN using Anaconda, run this command in your terminal: .. code-block:: console $ conda install -c silvandeleemput -c pytorch -c simpleitk -c conda-forge memcnn If you don't have `conda`_ installed, this `Anaconda installation guide`_ can guide you through the process. .. _conda: https://www.anaconda.com/ .. _Anaconda installation guide: https://docs.conda.io/projects/conda/en/latest/user-guide/install/ From sources ------------ The sources for MemCNN can be downloaded from the `Github repo`_. You can either clone the public repository: .. code-block:: console $ git clone git://github.com/silvandeleemput/memcnn Or download the `tarball`_: .. code-block:: console $ curl -OL https://github.com/silvandeleemput/memcnn/tarball/master Once you have a copy of the source, you can install it with: .. code-block:: console $ python setup.py install .. _Github repo: https://github.com/silvandeleemput/memcnn .. _tarball: https://github.com/silvandeleemput/memcnn/tarball/master Using docker ------------ MemCNN has several pre-build docker images that are hosted on dockerhub. You can directly pull these and to have a working environment for running the experiments. Run image from repository ^^^^^^^^^^^^^^^^^^^^^^^^^ Run the latest docker build of MemCNN from the repository (automatically pulls the image): .. code-block:: console $ docker run --shm-size=4g --runtime=nvidia -it silvandeleemput/memcnn:latest For ``--runtime=nvidia`` to work `nvidia-docker `__ must be installed on your system. It can be omitted but this will drop GPU training support. This will open a preconfigured bash shell, which is correctly configured to run the experiments. The latest version has Ubuntu 18.04 and Python 3.7 installed. 
By default, the datasets and experimental results will be put inside the created docker container under: ``/home/user/data`` and ``/home/user/experiments`` respectively.

Build image from source
^^^^^^^^^^^^^^^^^^^^^^^

Requirements:

- NVIDIA graphics card and the proper NVIDIA-drivers on your system

The following bash commands will clone this repository and do a one-time build of the docker image with the right environment installed:

.. code-block:: console

    $ git clone https://github.com/silvandeleemput/memcnn.git
    $ docker build ./memcnn/docker --tag=silvandeleemput/memcnn:latest

After the one-time install on your machine, the docker image can be invoked using the same commands as listed above.

Experiment configuration file
-----------------------------

To run the experiments, MemCNN requires setting up a configuration file containing locations to put the data files. This step is not necessary for the docker builds. The configuration file ``config.json`` goes in the ``/memcnn/config/`` directory of the library and should be formatted as follows:

.. code:: json

    {
        "data_dir": "/home/user/data",
        "results_dir": "/home/user/experiments"
    }

* data_dir : location for storing the input training datasets
* results_dir : location for storing the experiment files during training

Change the data paths to your liking. If you are unsure where MemCNN and/or the configuration file is located on your machine run:

.. code-block:: console

    $ python -m memcnn.train

If the configuration file is not set up correctly, this command should give the user the correct path to the configuration file. Next, create/edit the file at the given location.


================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=python -msphinx
)
set SOURCEDIR=.
set BUILDDIR=_build set SPHINXPROJ=memcnn if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The Sphinx module was not found. Make sure you have Sphinx installed, echo.then set the SPHINXBUILD environment variable to point to the full echo.path of the 'sphinx-build' executable. Alternatively you may add the echo.Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/modules.rst ================================================ ======= Modules ======= .. automodule:: memcnn :members: is_invertible_module .. autoclass:: memcnn.InvertibleModuleWrapper :members: forward, inverse .. autoclass:: memcnn.AdditiveCoupling :members: forward, inverse .. autoclass:: memcnn.AffineCoupling :members: forward, inverse .. autoclass:: memcnn.AffineAdapterNaive .. autoclass:: memcnn.AffineAdapterSigmoid .. autoclass:: memcnn.ReversibleBlock :members: forward, inverse ================================================ FILE: docs/readme.rst ================================================ .. include:: ../README.rst ================================================ FILE: docs/usage.rst ================================================ ===== Usage ===== To use MemCNN in a project:: import memcnn Examples -------- Creating an AdditiveCoupling with memory savings ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. 
code:: python import torch import torch.nn as nn import memcnn # define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d class ExampleOperation(nn.Module): def __init__(self, channels): super(ExampleOperation, self).__init__() self.seq = nn.Sequential( nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), padding=1), nn.BatchNorm2d(num_features=channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.seq(x) # generate some random input data (batch_size, num_channels, y_elements, x_elements) X = torch.rand(2, 10, 8, 8) # application of the operation(s) the normal way model_normal = ExampleOperation(channels=10) model_normal.eval() Y = model_normal(X) # turn the ExampleOperation invertible using an additive coupling invertible_module = memcnn.AdditiveCoupling( Fm=ExampleOperation(channels=10 // 2), Gm=ExampleOperation(channels=10 // 2) ) # test that it is actually a valid invertible module (has a valid inverse method) assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape) # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True) # by default the module is set to training, the following sets this to evaluation # note that this is required to pass input tensors to the model with requires_grad=False (inference only) invertible_module_wrapper.eval() # test that the wrapped module is also a valid invertible module assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape) # compute the forward pass using the wrapper Y2 = invertible_module_wrapper.forward(X) # the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2 X2 = invertible_module_wrapper.inverse(Y2) # test that the input and approximation are similar assert torch.allclose(X, X2, atol=1e-06) Run 
PyTorch Experiments ----------------------- .. include:: ./usage_experiments.rst ================================================ FILE: docs/usage_experiments.rst ================================================ After installing MemCNN run: .. code:: bash python -m memcnn.train [MODEL] [DATASET] [--fresh] [--no-cuda] * Available values for ``DATASET`` are ``cifar10`` and ``cifar100``. * Available values for ``MODEL`` are ``resnet32``, ``resnet110``, ``resnet164``, ``revnet38``, ``revnet110``, ``revnet164`` * Use the ``--fresh`` flag to remove earlier experiment results. * Use the ``--no-cuda`` flag to train on the CPU rather than the GPU through CUDA. Datasets are automatically downloaded if they are not available. When using Python 3.* replace the ``python`` directive with the appropriate Python 3 directive. For example when using the MemCNN docker image use ``python3.7``. ================================================ FILE: docsRequirements.txt ================================================ sphinx sphinxcontrib-plantuml sphinxcontrib-ansibleautodoc sphinx_rtd_theme PyYAML mock ================================================ FILE: memcnn/.editorconfig ================================================ # http://editorconfig.org root = true [*] indent_style = space indent_size = 4 trim_trailing_whitespace = true insert_final_newline = true charset = utf-8 end_of_line = lf [*.bat] indent_style = tab end_of_line = crlf [LICENSE] insert_final_newline = false [Makefile] indent_style = tab ================================================ FILE: memcnn/__init__.py ================================================ # -*- coding: utf-8 -*- """Top-level package for MemCNN.""" __author__ = """Sil van de Leemput""" __email__ = 'silvandeleemput@gmail.com' __version__ = '1.5.2' from memcnn.models.revop import ReversibleBlock, InvertibleModuleWrapper, create_coupling, is_invertible_module from memcnn.models.additive import AdditiveCoupling from memcnn.models.affine import 
class Config(dict):
    """Configuration dictionary for MemCNN.

    Behaves exactly like a ``dict``. When constructed without an explicit
    dictionary it populates itself from the JSON file returned by
    :meth:`get_filename` (``config.json`` next to this module).
    """

    def __init__(self, dic=None, verbose=False):
        """
        Parameters
        ----------
        dic : dict, optional
            Initial configuration values. When ``None``, the default JSON
            configuration file is read from disk instead.
        verbose : bool, optional
            When ``True`` and the default file is being loaded, print its path.
        """
        super(Config, self).__init__()
        source = dic
        if source is None:
            fname = self.get_filename()
            if verbose:
                print("loading default {0}".format(fname))
            with open(fname, "r") as fp:
                source = json.load(fp)
        self.update(source)

    @staticmethod
    def get_filename():
        # Full path of the default configuration file ("config.json").
        return os.path.join(Config.get_dir(), "config.json")

    @staticmethod
    def get_dir():
        # Directory containing this module; the config file lives beside it.
        return os.path.dirname(__file__)
"block":"memcnn.models.resnet.RevBasicBlock" } }, "revnet110": { "base": "revnet38", "model_params": { "layers":[9, 9, 9], "channels_per_layer":[32,32,64,128] } }, "revnet164": { "base": "revnet110", "model_params": { "block":"memcnn.models.resnet.RevBottleneck" } }, "cifar10": { "data_loader": "memcnn.data.cifar.get_cifar_data_loaders", "data_loader_params": { "dataset": "torchvision.datasets.CIFAR10", "workers": 16 }, "model_params": { "num_classes":10 } }, "cifar100": { "data_loader": "memcnn.data.cifar.get_cifar_data_loaders", "data_loader_params": { "dataset": "torchvision.datasets.CIFAR100", "workers": 16 }, "model_params": { "num_classes":100 } } } ================================================ FILE: memcnn/config/tests/__init__.py ================================================ ================================================ FILE: memcnn/config/tests/test_config.py ================================================ import unittest import json import os from memcnn.experiment.factory import load_experiment_config, experiment_config_parser from memcnn.config import Config import memcnn.config class ConfigTestCase(unittest.TestCase): class ConfigTest(Config): @staticmethod def get_filename(): return os.path.join(Config.get_dir(), "config.json.example") def setUp(self): self.config = ConfigTestCase.ConfigTest() self.config_fname = os.path.join(os.path.dirname(__file__), "..", "config.json.example") self.experiments_fname = os.path.join(os.path.dirname(__file__), "..", "experiments.json") def load_json_file(fname): with open(fname, 'r') as f: data = json.load(f) return data self.load_json_file = load_json_file def test_loading_main_config(self): self.assertTrue(os.path.exists(self.config.get_filename())) data = self.config self.assertTrue(isinstance(data, dict)) self.assertTrue("data_dir" in data) self.assertTrue("results_dir" in data) def test_loading_experiments_config(self): self.assertTrue(os.path.exists(self.experiments_fname)) data = 
def random_crop_transform(x, crop_size=3, img_size=(32, 32)):
    """Zero-pad ``x`` and return a randomly offset crop of size ``img_size``.

    Parameters
    ----------
    x : numpy.ndarray
        Image array of shape (H, W, C).
    crop_size : int, optional
        Maximum random offset; the image is padded by ``(crop_size + 1) // 2``
        pixels on each side of both spatial axes.
    img_size : tuple of int, optional
        (height, width) of the returned crop.
    """
    pad = (crop_size + 1) // 2
    padded = np.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant')
    # Two sequential draws, preserving the RNG call order of the original code.
    off_h = np.random.randint(crop_size + 1)
    off_w = np.random.randint(crop_size + 1)
    return padded[off_h:off_h + img_size[0], off_w:off_w + img_size[1], :]


def tonumpy_fn(x):
    """Convert a PIL-style image (exposes ``getdata()`` and ``size``) to an (H, W, 3) ndarray."""
    # PIL reports size as (width, height); reshape to (height, width, channels).
    height, width = x.size[1], x.size[0]
    return np.array(x.getdata()).reshape(height, width, 3)


def random_lr_flip_fn(x):
    """Flip ``x`` left-right with probability 0.5; otherwise return it unchanged."""
    if np.random.random() >= 0.5:
        return np.copy(x[:, ::-1, :])
    return x


def mean_subtract_fn(x, mean=0):
    """Return ``x`` cast to float32 with ``mean`` (scalar or array) subtracted."""
    return x.astype(np.float32) - mean


def reformat_fn(x):
    """Rearrange an (H, W, C) image into PyTorch's (C, H, W) float32 layout."""
    return x.transpose(2, 0, 1).astype(np.float32)
class NSamplesRandomSampler(Sampler):
    """Samples elements randomly, with replacement, always in blocks of all elements
    of the dataset. Only the remainder will be sampled with less elements.

    Arguments:
        data_source (Dataset): dataset to sample from
        nsamples (int): number of total samples. Note: will always be cast to int
    """

    @property
    def nsamples(self):
        return self._nsamples

    @nsamples.setter
    def nsamples(self, value):
        # Cast guards against numpy integers / floats being passed in.
        self._nsamples = int(value)

    def __init__(self, data_source, nsamples):
        self.data_source = data_source
        self.nsamples = nsamples

    def __iter__(self):
        samples = torch.LongTensor()
        len_data_source = len(self.data_source)
        # Full blocks: each is an independent random permutation of the whole dataset.
        for _ in range(self.nsamples // len_data_source):
            samples = torch.cat((samples, torch.randperm(len_data_source).long()))
        # Remainder block: draw unique indices from the FULL dataset.
        # BUGFIX: previously this was torch.randperm(remainder), which could only
        # ever yield indices 0..remainder-1 instead of a random subset of the dataset.
        remainder = self.nsamples % len_data_source
        if remainder > 0:
            samples = torch.cat((samples, torch.randperm(len_data_source)[:remainder].long()))
        return iter(samples)

    def __len__(self):
        return self.nsamples
__getitem__(self, idx): img = self.train_data[idx] if self.train else self.test_data[idx] img = Image.fromarray(img) img = self.transform(img) return img, np.array(idx) def __len__(self): return len(self.train_data) if self.train else len(self.test_data) max_epoch = 10 batch_size = 2 workers = 0 train_loader, val_loader = get_cifar_data_loaders(TestDataset, '', max_epoch, batch_size, workers=workers) xsize = (batch_size, 3, 32, 32) ysize = (batch_size, ) count = 0 for x, y in train_loader: count += 1 assert x.shape == xsize assert y.shape == ysize assert count == max_epoch assert count == len(train_loader) count = 0 for x, y in val_loader: count += 1 assert x.shape == xsize assert y.shape == ysize assert count == len(val_loader.dataset) // batch_size assert count == len(val_loader) ================================================ FILE: memcnn/data/tests/test_sampling.py ================================================ import pytest from memcnn.data.sampling import NSamplesRandomSampler import torch.utils.data as data import numpy as np @pytest.mark.parametrize('nsamples,data_samples', [(1, 1), (14, 10), (10, 14), (5, 1), (1, 5), (0, 10), (np.array(4, dtype=np.int64), 12), (np.int64(4), 12), (np.array(12, dtype=np.int64), 3), (np.int64(12), 3)]) @pytest.mark.parametrize('assign_after_creation', [False, True]) def test_random_sampler(nsamples, data_samples, assign_after_creation): class TestDataset(data.Dataset): def __init__(self, elements): self.elements = elements def __getitem__(self, idx): return idx, idx def __len__(self): return self.elements datasrc = TestDataset(data_samples) sampler = NSamplesRandomSampler(datasrc, nsamples=nsamples if not assign_after_creation else -1) if assign_after_creation: sampler.nsamples = nsamples count = 0 elements = [] for e in sampler: elements.append(e) count += 1 if count % data_samples == 0: assert len(np.unique(elements)) == len(elements) elements = [] assert count == nsamples assert len(sampler) == nsamples assert 
sampler.__len__() == nsamples ================================================ FILE: memcnn/examples/minimal.py ================================================ import torch import torch.nn as nn import memcnn # define a new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d class ExampleOperation(nn.Module): def __init__(self, channels): super(ExampleOperation, self).__init__() self.seq = nn.Sequential( nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), padding=1), nn.BatchNorm2d(num_features=channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.seq(x) # generate some random input data (batch_size, num_channels, y_elements, x_elements) X = torch.rand(2, 10, 8, 8) # application of the operation(s) the normal way model_normal = ExampleOperation(channels=10) model_normal.eval() Y = model_normal(X) # turn the ExampleOperation invertible using an additive coupling invertible_module = memcnn.AdditiveCoupling( Fm=ExampleOperation(channels=10 // 2), Gm=ExampleOperation(channels=10 // 2) ) # test that it is actually a valid invertible module (has a valid inverse method) assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape) # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=True, keep_input_inverse=True) # by default the module is set to training, the following sets this to evaluation # note that this is required to pass input tensors to the model with requires_grad=False (inference only) invertible_module_wrapper.eval() # test that the wrapped module is also a valid invertible module assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape) # compute the forward pass using the wrapper Y2 = invertible_module_wrapper.forward(X) # the input (X) can be approximated (X2) by applying the inverse method of the 
wrapper on Y2 X2 = invertible_module_wrapper.inverse(Y2) # test that the input and approximation are similar assert torch.allclose(X, X2, atol=1e-06) ================================================ FILE: memcnn/examples/test_examples.py ================================================ import torch import sys def test_minimal(): import minimal # Input and inversed output should be approximately the same assert torch.allclose(minimal.X, minimal.X2, atol=1e-06) # Output of the wrapped invertible module is unlikely to match the normal output of F assert not torch.allclose(minimal.Y2, minimal.Y) # Cleanup minimal module and variables del minimal.X del minimal.Y del minimal.Y2 del minimal.X2 del minimal del sys.modules['minimal'] ================================================ FILE: memcnn/experiment/__init__.py ================================================ ================================================ FILE: memcnn/experiment/factory.py ================================================ import json import copy def get_attr_from_module(pclass): pclass = pclass.rsplit(".", 1) mod = __import__(pclass[0], fromlist=[str(pclass[1])]) return getattr(mod, pclass[1]) def load_experiment_config(experiments_file, experiment_tags): with open(experiments_file, 'r') as f: data = json.load(f) d = {} for tag in experiment_tags: _inject_items(build_dict(data, tag), d) return d def _inject_items(tempdict, d): """inject tempdict into d""" for k, v in tempdict.items(): if isinstance(v, dict): if k not in d: d[k] = {} d[k] = _inject_items(v, d[k]) else: d[k] = v return d def build_dict(experiments_dict, experiment_name, classhist=None): tempdict = experiments_dict[experiment_name] if classhist is None: classhist = [] classhist.append(experiment_name) if not ('base' in tempdict) or (tempdict['base'] is None): return copy.deepcopy(tempdict) elif tempdict['base'] in classhist: raise RuntimeError('Circular dependency found...') else: d = build_dict(experiments_dict, tempdict['base'], 
classhist) return _inject_items(tempdict, d) def experiment_config_parser(d, data_dir, workers=None): trainer = get_attr_from_module(d['trainer']) model = get_attr_from_module(d['model']) model_params = copy.deepcopy(d['model_params']) if 'block' in model_params: model_params['block'] = get_attr_from_module(model_params['block']) model = model(**model_params) optimizer = get_attr_from_module(d['optimizer']) optimizer = optimizer(model.parameters(), **d['optimizer_params']) dl_params = copy.deepcopy(d['data_loader_params']) dl_params['dataset'] = get_attr_from_module(dl_params['dataset']) dl_params['data_dir'] = data_dir dl_params['workers'] = dl_params['workers'] if workers is None else workers train_loader, val_loader = get_attr_from_module(d['data_loader'])(**dl_params) trainer_params = {} if 'trainer_params' in d: trainer_params = copy.deepcopy(d['trainer_params']) if 'loss' in trainer_params: trainer_params['loss'] = get_attr_from_module(trainer_params['loss'])() trainer_params = dict( train_loader=train_loader, test_loader=val_loader, **trainer_params ) return model, optimizer, trainer, trainer_params ================================================ FILE: memcnn/experiment/manager.py ================================================ import os import glob import torch import logging import shutil import numpy as np class ExperimentManager(object): def __init__(self, experiment_dir, model=None, optimizer=None): self.logger = logging.getLogger(type(self).__name__) self.experiment_dir = experiment_dir self.model = model self.optimizer = optimizer self.model_dir = os.path.join(self.experiment_dir, "state", "model") self.optim_dir = os.path.join(self.experiment_dir, "state", "optimizer") self.log_dir = os.path.join(self.experiment_dir, "log") self.dirs = (self.experiment_dir, self.model_dir, self.log_dir, self.optim_dir) def make_dirs(self): for d in self.dirs: if not os.path.exists(d): os.makedirs(d) assert(self.all_dirs_exists()) # nosec def delete_dirs(self): for 
d in self.dirs: if os.path.exists(d): shutil.rmtree(d) assert(not self.any_dir_exists()) # nosec def any_dir_exists(self): return any([os.path.exists(d) for d in self.dirs]) def all_dirs_exists(self): return all([os.path.exists(d) for d in self.dirs]) def save_model_state(self, epoch): model_fname = os.path.join(self.model_dir, "{}.pt".format(epoch)) self.logger.info("Saving model state to: {}".format(model_fname)) torch.save(self.model.state_dict(), model_fname) def load_model_state(self, epoch): model_fname = os.path.join(self.model_dir, "{}.pt".format(epoch)) self.logger.info("Loading model state from: {}".format(model_fname)) self.model.load_state_dict(torch.load(model_fname)) def save_optimizer_state(self, epoch): optim_fname = os.path.join(self.optim_dir, "{}.pt".format(epoch)) self.logger.info("Saving optimizer state to: {}".format(optim_fname)) torch.save(self.optimizer.state_dict(), optim_fname) def load_optimizer_state(self, epoch): optim_fname = os.path.join(self.optim_dir, "{}.pt".format(epoch)) self.logger.info("Loading optimizer state from {}".format(optim_fname)) self.optimizer.load_state_dict(torch.load(optim_fname)) def save_train_state(self, epoch): self.save_model_state(epoch) self.save_optimizer_state(epoch) def load_train_state(self, epoch): self.load_model_state(epoch) self.load_optimizer_state(epoch) def get_last_model_iteration(self): return np.array([0] + [int(os.path.basename(e).split(".")[0]) for e in glob.glob(os.path.join(self.model_dir, "*.pt"))]).max() def load_last_train_state(self): self.load_train_state(self.get_last_model_iteration()) ================================================ FILE: memcnn/experiment/tests/__init__.py ================================================ ================================================ FILE: memcnn/experiment/tests/test_factory.py ================================================ import pytest import os import memcnn.experiment.factory from memcnn.config import Config def 
test_get_attr_from_module(): a = memcnn.experiment.factory.get_attr_from_module('memcnn.experiment.factory.get_attr_from_module') assert a is memcnn.experiment.factory.get_attr_from_module def test_load_experiment_config(): cfg_fname = os.path.join(Config.get_dir(), 'experiments.json') memcnn.experiment.factory.load_experiment_config(cfg_fname, ['cifar10', 'resnet110']) @pytest.mark.skip(reason="Covered more efficiently by test_train.test_run_experiment") def test_experiment_config_parser(tmp_path): tmp_data_dir = tmp_path / "tmpdata" cfg_fname = os.path.join(Config.get_dir(), 'experiments.json') cfg = memcnn.experiment.factory.load_experiment_config(cfg_fname, ['cifar10', 'resnet110']) memcnn.experiment.factory.experiment_config_parser(cfg, str(tmp_data_dir), workers=None) def test_circular_dependency(tmp_path): p = str(tmp_path / "circular.json") content = u'{ "circ": { "base": "circ" } }' with open(p, 'w') as fh: fh.write(content) with open(p, 'r') as fh: assert fh.read() == content with pytest.raises(RuntimeError): memcnn.experiment.factory.load_experiment_config(p, ['circ']) ================================================ FILE: memcnn/experiment/tests/test_manager.py ================================================ from memcnn.experiment.manager import ExperimentManager import torch.nn def test_experiment_manager(tmp_path): exp_dir = tmp_path / "test_exp_dir" man = ExperimentManager(str(exp_dir)) assert man.model is None assert man.optimizer is None man.make_dirs() assert exp_dir.exists() assert (exp_dir / "log").exists() assert (exp_dir / "state" / "model").exists() assert (exp_dir / "state" / "optimizer").exists() assert man.all_dirs_exists() assert man.any_dir_exists() man.delete_dirs() assert not exp_dir.exists() assert not (exp_dir / "log").exists() assert not (exp_dir / "state" / "model").exists() assert not (exp_dir / "state" / "optimizer").exists() assert not man.all_dirs_exists() assert not man.any_dir_exists() man.make_dirs() man.model = 
torch.nn.Conv2d(2, 1, 3) w = man.model.weight.clone() man.save_model_state(0) with torch.no_grad(): man.model.weight.zero_() man.save_model_state(100) assert not man.model.weight.equal(w) assert man.get_last_model_iteration() == 100 man.load_model_state(0) assert man.model.weight.equal(w) optimizer = torch.optim.SGD(man.model.parameters(), lr=0.01, momentum=0.1) man.optimizer = optimizer man.save_train_state(100) w = man.model.weight.clone() sd = man.optimizer.state_dict().copy() man.model.train() x = torch.ones(5, 2, 5, 5) x.requires_grad = True y = torch.ones(5, 1, 3, 3) y.requires_grad = False ypred = man.model(x) loss = torch.nn.MSELoss()(ypred, y) man.optimizer.zero_grad() loss.backward() man.optimizer.step() man.save_train_state(101) assert not man.model.weight.equal(w) assert sd != man.optimizer.state_dict() w2 = man.model.weight.clone() sd2 = man.optimizer.state_dict().copy() man.load_train_state(100) assert man.model.weight.equal(w) assert sd == man.optimizer.state_dict() man.load_last_train_state() # should be 101 assert not man.model.weight.equal(w) assert sd != man.optimizer.state_dict() assert man.model.weight.equal(w2) def retrieve_mom_buffer(sd): keys = [e for e in sd['state'].keys()] if len(keys) == 0: return torch.zero(0) else: return sd['state'][keys[0]]['momentum_buffer'] assert torch.equal(retrieve_mom_buffer(sd2), retrieve_mom_buffer(man.optimizer.state_dict())) ================================================ FILE: memcnn/models/__init__.py ================================================ ================================================ FILE: memcnn/models/additive.py ================================================ import warnings import torch import torch.nn as nn import copy from torch import set_grad_enabled class AdditiveCoupling(nn.Module): def __init__(self, Fm, Gm=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1): """ This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and 
:math:`Gm` according to: :math:`(x1, x2) = x` :math:`y1 = x1 + Fm(x2)` :math:`y2 = x2 + Gm(y1)` :math:`y = (y1, y2)` Parameters ---------- Fm : :obj:`torch.nn.Module` A torch.nn.Module encapsulating an arbitrary function Gm : :obj:`torch.nn.Module` A torch.nn.Module encapsulating an arbitrary function (If not specified a deepcopy of Fm is used as a Module) implementation_fwd : :obj:`int` Switch between different Additive Operation implementations for forward pass. Default = -1 implementation_bwd : :obj:`int` Switch between different Additive Operation implementations for inverse pass. Default = -1 split_dim : :obj:`int` Dimension to split the input tensors on. Default = 1, generally corresponding to channels. """ super(AdditiveCoupling, self).__init__() # mirror the passed module, without parameter sharing... if Gm is None: Gm = copy.deepcopy(Fm) self.Gm = Gm self.Fm = Fm self.implementation_fwd = implementation_fwd self.implementation_bwd = implementation_bwd self.split_dim = split_dim if implementation_bwd != -1 or implementation_fwd != -1: warnings.warn("Other implementations than the default (-1) are now deprecated.", DeprecationWarning) def forward(self, x): args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()] if self.implementation_fwd == 0: out = AdditiveBlockFunction.apply(*args) elif self.implementation_fwd == 1: out = AdditiveBlockFunction2.apply(*args) elif self.implementation_fwd == -1: x1, x2 = torch.chunk(x, 2, dim=self.split_dim) x1, x2 = x1.contiguous(), x2.contiguous() fmd = self.Fm.forward(x2) y1 = x1 + fmd gmd = self.Gm.forward(y1) y2 = x2 + gmd out = torch.cat([y1, y2], dim=self.split_dim) else: raise NotImplementedError("Selected implementation ({}) not implemented..." 
.format(self.implementation_fwd)) return out def inverse(self, y): args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()] if self.implementation_bwd == 0: x = AdditiveBlockInverseFunction.apply(*args) elif self.implementation_bwd == 1: x = AdditiveBlockInverseFunction2.apply(*args) elif self.implementation_bwd == -1: y1, y2 = torch.chunk(y, 2, dim=self.split_dim) y1, y2 = y1.contiguous(), y2.contiguous() gmd = self.Gm.forward(y1) x2 = y2 - gmd fmd = self.Fm.forward(x2) x1 = y1 - fmd x = torch.cat([x1, x2], dim=self.split_dim) else: raise NotImplementedError("Inverse for selected implementation ({}) not implemented..." .format(self.implementation_bwd)) return x class AdditiveBlock(AdditiveCoupling): def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1): warnings.warn("This class has been deprecated. Use the AdditiveCoupling class instead.", DeprecationWarning) super(AdditiveBlock, self).__init__(Fm=Fm, Gm=Gm, implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd) class AdditiveBlockFunction(torch.autograd.Function): @staticmethod def forward(ctx, xin, Fm, Gm, *weights): """Forward pass computes: {x1, x2} = x y1 = x1 + Fm(x2) y2 = x2 + Gm(y1) output = {y1, y2} Parameters ---------- ctx : torch.autograd.Function The backward pass context object x : TorchTensor Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions Fm : nn.Module Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape Gm : nn.Module Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape *weights : TorchTensor weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... 
Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert(xin.shape[1] % 2 == 0)  # nosec
        # store partition size, Fm and Gm functions in context
        ctx.Fm = Fm
        ctx.Gm = Gm
        with torch.no_grad():
            x = xin.detach()
            # partition in two equally sized set of channels
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # compute outputs
            fmr = Fm.forward(x2)
            y1 = x1 + fmr
            # release x1's storage (set_() with no source makes the tensor empty);
            # x1 can be recomputed from the saved output during backward
            x1.set_()
            del x1
            gmr = Gm.forward(y1)
            y2 = x2 + gmr
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1)
        ctx.save_for_backward(xin, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # retrieve weight references
        Fm, Gm = ctx.Fm, ctx.Gm
        # retrieve input and output references
        xin, output = ctx.saved_tensors
        x = xin.detach()
        x1, x2 = torch.chunk(x, 2, dim=1)
        GWeights = [p for p in Gm.parameters()]
        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0  # nosec
        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            x1.requires_grad_()
            x2.requires_grad_()
            y1 = x1 + Fm.forward(x2)
            y2 = x2 + Gm.forward(y1)
            y = torch.cat([y1, y2], dim=1)
            # perform full backward pass on graph...
            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)
            # gradient tuple layout: (dx1, dx2, Gm grads..., Fm grads...)
            GWgrads = dd[2:2 + len(GWeights)]
            FWgrads = dd[2 + len(GWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)
        return (grad_input, None, None) + FWgrads + GWgrads


class AdditiveBlockInverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(cty, y, Fm, Gm, *weights):
        """Forward pass computes:
        {y1, y2} = y
        x2 = y2 - Gm(y1)
        x1 = y1 - Fm(x2)
        output = {x1, x2}

        Parameters
        ----------
        cty : torch.autograd.Function
            The backward pass context object
        y : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert(y.shape[1] % 2 == 0)  # nosec
        # store partition size, Fm and Gm functions in context
        cty.Fm = Fm
        cty.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # compute outputs
            gmr = Gm.forward(y1)
            x2 = y2 - gmr
            # free intermediate storage as soon as it is no longer needed
            y2.set_()
            del y2
            fmr = Fm.forward(x2)
            x1 = y1 - fmr
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1)
            x1.set_()
            x2.set_()
            del x1, x2
        # save the (empty) input and (non-empty) output variables
        cty.save_for_backward(y.data, output)
        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        # retrieve weight references
        Fm, Gm = cty.Fm, cty.Gm
        # retrieve input and output references
        yin, output = cty.saved_tensors
        y = yin.detach()
        y1, y2 = torch.chunk(y, 2, dim=1)
        FWeights = [p for p in Fm.parameters()]
        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0  # nosec
        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            y2.requires_grad = True
            y1.requires_grad = True
            x2 = y2 - Gm.forward(y1)
            x1 = y1 - Fm.forward(x2)
            x = torch.cat([x1, x2], dim=1)
            # perform full backward pass on graph...
            dd = torch.autograd.grad(x, (y2, y1) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)
            # gradient tuple layout: (dy2, dy1, Fm grads..., Gm grads...)
            FWgrads = dd[2:2 + len(FWeights)]
            GWgrads = dd[2 + len(FWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)
        return (grad_input, None, None) + FWgrads + GWgrads


class AdditiveBlockFunction2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xin, Fm, Gm, *weights):
        """Forward pass computes:
        {x1, x2} = x
        y1 = x1 + Fm(x2)
        y2 = x2 + Gm(y1)
        output = {y1, y2}

        Parameters
        ----------
        ctx : torch.autograd.Function
            The backward pass context object
        x : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert xin.shape[1] % 2 == 0  # nosec
        # store partition size, Fm and Gm functions in context
        ctx.Fm = Fm
        ctx.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            x = xin.detach()
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # compute outputs
            fmr = Fm.forward(x2)
            y1 = x1 + fmr
            # free x1's storage; inputs are recomputed from outputs in backward
            x1.set_()
            del x1
            gmr = Gm.forward(y1)
            y2 = x2 + gmr
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1).detach_()
        # save the input and output variables
        ctx.save_for_backward(x, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        Fm, Gm = ctx.Fm, ctx.Gm
        # are all variable objects now
        x, output = ctx.saved_tensors
        with torch.no_grad():
            y1, y2 = torch.chunk(output, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # partition output gradient also on channels
            assert(grad_output.shape[1] % 2 == 0)  # nosec
            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()
        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, x2_stop, GW, FW
        # Also recompute inputs (x1, x2) from outputs (y1, y2)
        with set_grad_enabled(True):
            z1_stop = y1.detach()
            z1_stop.requires_grad = True
            G_z1 = Gm.forward(z1_stop)
            x2 = y2 - G_z1
            x2_stop = x2.detach()
            x2_stop.requires_grad = True
            F_x2 = Fm.forward(x2_stop)
            x1 = y1 - F_x2
            x1_stop = x1.detach()
            x1_stop.requires_grad = True
            # compute outputs building a sub-graph
            y1 = x1_stop + F_x2
            y2 = x2_stop + G_z1
            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(y2, (z1_stop,) + tuple(Gm.parameters()), y2_grad, retain_graph=False)
            z1_grad = dd[0] + y1_grad
            GWgrads = dd[1:]
            dd = torch.autograd.grad(y1, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)
            FWgrads = dd[2:]
            x2_grad = dd[1] + y2_grad
            x1_grad = dd[0]
            grad_input = torch.cat([x1_grad, x2_grad], dim=1)
        return (grad_input, None, None) + FWgrads + GWgrads


class AdditiveBlockInverseFunction2(torch.autograd.Function):
    @staticmethod
    def forward(cty, y, Fm, Gm, *weights):
        """Forward pass computes:
        {y1, y2} = y
        x2 = y2 - Gm(y1)
        x1 = y1 - Fm(x2)
        output = {x1, x2}

        Parameters
        ----------
        cty : torch.autograd.Function
            The backward pass context object
        y : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn}

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert(y.shape[1] % 2 == 0)  # nosec
        # store partition size, Fm and Gm functions in context
        cty.Fm = Fm
        cty.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # compute outputs
            gmr = Gm.forward(y1)
            x2 = y2 - gmr
            y2.set_()
            del y2
            fmr = Fm.forward(x2)
            x1 = y1 - fmr
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1).detach_()
        # save the input and output variables
        cty.save_for_backward(y, output)
        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        Fm, Gm = cty.Fm, cty.Gm
        # are all variable objects now
        y, output = cty.saved_tensors
        with torch.no_grad():
            x1, x2 = torch.chunk(output, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # partition output gradient also on channels
            assert(grad_output.shape[1] % 2 == 0)  # nosec
            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()
        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, y1_stop, GW, FW
        # Also recompute inputs (y1, y2) from outputs (x1, x2)
        with set_grad_enabled(True):
            z1_stop = x2.detach()
            z1_stop.requires_grad = True
            F_z1 = Fm.forward(z1_stop)
            y1 = x1 + F_z1
            y1_stop = y1.detach()
            y1_stop.requires_grad = True
            G_y1 = Gm.forward(y1_stop)
            y2 = x2 + G_y1
            y2_stop = y2.detach()
            y2_stop.requires_grad = True
            # compute outputs building a sub-graph
            z1 = y2_stop - G_y1
            x1 = y1_stop - F_z1
            x2 = z1
            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(x1, (z1_stop,) + tuple(Fm.parameters()), x1_grad)
            z1_grad = dd[0] + x2_grad
            FWgrads = dd[1:]
            dd = torch.autograd.grad(x2, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)
            GWgrads = dd[2:]
y1_grad = dd[1] + x1_grad y2_grad = dd[0] grad_input = torch.cat([y1_grad, y2_grad], dim=1) return (grad_input, None, None) + FWgrads + GWgrads ================================================ FILE: memcnn/models/affine.py ================================================ import torch import torch.nn as nn import copy import warnings from torch import set_grad_enabled warnings.filterwarnings(action='ignore', category=UserWarning) class AffineAdapterNaive(nn.Module): """ Naive Affine adapter Outputs exp(f(x)), f(x) given f(.) and x """ def __init__(self, module): super(AffineAdapterNaive, self).__init__() self.f = module def forward(self, x): t = self.f(x) s = torch.exp(t) return s, t class AffineAdapterSigmoid(nn.Module): """ Sigmoid based affine adapter Partitions the output h of f(x) = h into s and t by extracting every odd and even channel Outputs sigmoid(s), t """ def __init__(self, module): super(AffineAdapterSigmoid, self).__init__() self.f = module def forward(self, x): h = self.f(x) assert h.shape[1] % 2 == 0 # nosec scale = torch.sigmoid(h[:, 1::2] + 2.0) shift = h[:, 0::2] return scale, shift class AffineCoupling(nn.Module): def __init__(self, Fm, Gm=None, adapter=None, implementation_fwd=-1, implementation_bwd=-1, split_dim=1): """ This computes the output :math:`y` on forward given input :math:`x` and arbitrary modules :math:`Fm` and :math:`Gm` according to: :math:`(x1, x2) = x` :math:`(log({s1}), t1) = Fm(x2)` :math:`s1 = exp(log({s1}))` :math:`y1 = s1 * x1 + t1` :math:`(log({s2}), t2) = Gm(y1)` :math:`s2 = exp(log({s2}))` :math:`y2 = s2 * x2 + t2` :math:`y = (y1, y2)` Parameters ---------- Fm : :obj:`torch.nn.Module` A torch.nn.Module encapsulating an arbitrary function Gm : :obj:`torch.nn.Module` A torch.nn.Module encapsulating an arbitrary function (If not specified a deepcopy of Gm is used as a Module) adapter : :obj:`torch.nn.Module` class An optional wrapper class A for Fm and Gm which must output s, t = A(x) with shape(s) = shape(t) = shape(x) s, 
t are respectively the scale and shift tensors for the affine coupling. implementation_fwd : :obj:`int` Switch between different Affine Operation implementations for forward pass. Default = -1 implementation_bwd : :obj:`int` Switch between different Affine Operation implementations for inverse pass. Default = -1 split_dim : :obj:`int` Dimension to split the input tensors on. Default = 1, generally corresponding to channels. """ super(AffineCoupling, self).__init__() # mirror the passed module, without parameter sharing... if Gm is None: Gm = copy.deepcopy(Fm) # apply the adapter class if it is given self.Gm = adapter(Gm) if adapter is not None else Gm self.Fm = adapter(Fm) if adapter is not None else Fm self.implementation_fwd = implementation_fwd self.implementation_bwd = implementation_bwd self.split_dim = split_dim if implementation_bwd != -1 or implementation_fwd != -1: warnings.warn("Other implementations than the default (-1) are now deprecated.", DeprecationWarning) def forward(self, x): args = [x, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()] if self.implementation_fwd == 0: out = AffineBlockFunction.apply(*args) elif self.implementation_fwd == 1: out = AffineBlockFunction2.apply(*args) elif self.implementation_fwd == -1: x1, x2 = torch.chunk(x, 2, dim=self.split_dim) x1, x2 = x1.contiguous(), x2.contiguous() fmr1, fmr2 = self.Fm.forward(x2) y1 = (x1 * fmr1) + fmr2 gmr1, gmr2 = self.Gm.forward(y1) y2 = (x2 * gmr1) + gmr2 out = torch.cat([y1, y2], dim=self.split_dim) else: raise NotImplementedError("Selected implementation ({}) not implemented..." 
    def inverse(self, y):
        # Dispatch over the (deprecated) implementation variants; -1 is the
        # default eager implementation.  Fm/Gm parameters are passed along so
        # the autograd.Function variants can return gradients for them.
        args = [y, self.Fm, self.Gm] + [w for w in self.Fm.parameters()] + [w for w in self.Gm.parameters()]
        if self.implementation_bwd == 0:
            x = AffineBlockInverseFunction.apply(*args)
        elif self.implementation_bwd == 1:
            x = AffineBlockInverseFunction2.apply(*args)
        elif self.implementation_bwd == -1:
            # Invert the forward coupling: first undo Gm on y2, then Fm on y1.
            y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
            y1, y2 = y1.contiguous(), y2.contiguous()
            gmr1, gmr2 = self.Gm.forward(y1)
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = self.Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            x = torch.cat([x1, x2], dim=self.split_dim)
        else:
            raise NotImplementedError("Inverse for selected implementation ({}) not implemented..."
                                      .format(self.implementation_bwd))
        return x


class AffineBlock(AffineCoupling):
    # Deprecated alias kept for backwards compatibility; behaves exactly like
    # AffineCoupling apart from different implementation defaults.
    def __init__(self, Fm, Gm=None, implementation_fwd=1, implementation_bwd=1):
        warnings.warn("This class has been deprecated. Use the AffineCoupling class instead.",
                      DeprecationWarning)
        super(AffineBlock, self).__init__(Fm=Fm, Gm=Gm,
                                          implementation_fwd=implementation_fwd,
                                          implementation_bwd=implementation_bwd)


class AffineBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xin, Fm, Gm, *weights):
        """Forward pass for the affine block computes:
        {x1, x2} = x
        {s1, t1} = Fm(x2)
        y1 = s1 * x1 + t1
        {s2, t2} = Gm(y1)
        y2 = s2 * x2 + t2
        output = {y1, y2}

        Parameters
        ----------
        ctx : torch.autograd.function context object
            The backward pass context object
        xin : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn};
            only listed so autograd asks this Function for their gradients

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert xin.shape[1] % 2 == 0  # nosec
        # store Fm and Gm functions in context for the backward pass
        ctx.Fm = Fm
        ctx.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            x = xin.detach()
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # compute outputs
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)
            y1 = (x1 * fmr1) + fmr2
            # free x1's storage as soon as it is no longer needed (memory saving)
            x1.set_()
            del x1
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)
            y2 = (x2 * gmr1) + gmr2
            # free x2's storage likewise
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1).detach_()
        # save the (empty) input and (non-empty) output variables
        ctx.save_for_backward(xin, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # retrieve weight references
        Fm, Gm = ctx.Fm, ctx.Gm
        # retrieve input and output references
        xin, output = ctx.saved_tensors
        x = xin.detach()
        x1, x2 = torch.chunk(x.detach(), 2, dim=1)
        GWeights = [p for p in Gm.parameters()]
        # partition output gradient also on channels
        assert (grad_output.shape[1] % 2 == 0)  # nosec
        with set_grad_enabled(True):
            # re-run the forward computation building a differentiable sub-graph
            x1.requires_grad = True
            x2.requires_grad = True
            fmr1, fmr2 = Fm.forward(x2)
            y1 = x1 * fmr1 + fmr2
            gmr1, gmr2 = Gm.forward(y1)
            y2 = x2 * gmr1 + gmr2
            y = torch.cat([y1, y2], dim=1)
            # perform full backward pass on graph; grad order below matches the
            # tuple (x1, x2, Gm params..., Fm params...)
            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)
            GWgrads = dd[2:2 + len(GWeights)]
            FWgrads = dd[2 + len(GWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)
        # None entries correspond to the non-tensor args (Fm, Gm) of forward
        return (grad_input, None, None) + FWgrads + GWgrads
class AffineBlockInverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(cty, yin, Fm, Gm, *weights):
        """Forward inverse pass for the affine block computes:
        {y1, y2} = y
        {s2, t2} = Gm(y1)
        x2 = (y2 - t2) / s2
        {s1, t1} = Fm(x2)
        x1 = (y1 - t1) / s1
        output = {x1, x2}

        Parameters
        ----------
        cty : torch.autograd.function context object
            The backward pass context object
        yin : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn};
            only listed so autograd asks this Function for their gradients

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert yin.shape[1] % 2 == 0  # nosec
        # store Fm and Gm functions in context for the backward pass
        cty.Fm = Fm
        cty.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized sets of channels
            y = yin.detach()
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # compute outputs; free y2/y1 storage as soon as no longer needed
            gmr1, gmr2 = Gm.forward(y1)
            x2 = (y2 - gmr2) / gmr1
            y2.set_()
            del y2
            fmr1, fmr2 = Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1).detach_()
        # save input and output variables
        cty.save_for_backward(yin, output)
        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        # retrieve module references
        Fm, Gm = cty.Fm, cty.Gm
        # retrieve input and output references
        yin, output = cty.saved_tensors
        y = yin.detach()
        y1, y2 = torch.chunk(y.detach(), 2, dim=1)
        # make leaf copies so requires_grad can be set safely
        y1, y2 = y1.contiguous(), y2.contiguous()
        # BUG FIX: count Fm's parameters (was erroneously Gm's), which
        # mis-sliced the gradient tuple whenever Fm and Gm differ in size
        FWeights = [p for p in Fm.parameters()]
        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0  # nosec
        with set_grad_enabled(True):
            # rebuild the inverse computation as a differentiable sub-graph
            y1.requires_grad = True
            y2.requires_grad = True
            gmr1, gmr2 = Gm.forward(y1)
            # BUG FIX: this line was commented out, so Fm.forward(x2) below
            # raised a NameError on any backward pass
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            x = torch.cat([x1, x2], dim=1)
            # perform full backward pass on the graph; the input tuple order
            # (y1, y2) matches the channel layout of yin = cat([y1, y2])
            # (BUG FIX: was (y2, y1), flipping the input gradient halves)
            dd = torch.autograd.grad(x, (y1, y2) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)
            FWgrads = dd[2:2 + len(FWeights)]
            GWgrads = dd[2 + len(FWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)
        # None entries correspond to the non-tensor args (Fm, Gm) of forward
        return (grad_input, None, None) + FWgrads + GWgrads
class AffineBlockFunction2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xin, Fm, Gm, *weights):
        """Forward pass for the affine block computes:
        {x1, x2} = x
        {s1, t1} = Fm(x2)
        y1 = s1 * x1 + t1
        {s2, t2} = Gm(y1)
        y2 = s2 * x2 + t2
        output = {y1, y2}

        Parameters
        ----------
        ctx : torch.autograd.function context object
            The backward pass context object
        xin : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn};
            only listed so autograd asks this Function for their gradients

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert xin.shape[1] % 2 == 0  # nosec
        # store Fm and Gm functions in context for the backward pass
        ctx.Fm = Fm
        ctx.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            x = xin.detach()
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # compute outputs
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)
            y1 = x1 * fmr1 + fmr2
            # free x1's storage as soon as it is no longer needed (memory saving)
            x1.set_()
            del x1
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)
            y2 = x2 * gmr1 + gmr2
            x2.set_()
            del x2
            output = torch.cat([y1, y2], dim=1).detach_()
        # save the input and output variables
        ctx.save_for_backward(xin, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        Fm, Gm = ctx.Fm, ctx.Gm
        # are all variable objects now
        x, output = ctx.saved_tensors
        with set_grad_enabled(False):
            y1, y2 = torch.chunk(output, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)  # nosec
            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()
        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, x2_stop, GW, FW
        # Also recompute inputs (x1, x2) from outputs (y1, y2)
        with set_grad_enabled(True):
            z1_stop = y1
            z1_stop.requires_grad = True
            G_z11, G_z12 = Gm.forward(z1_stop)
            x2 = (y2 - G_z12) / G_z11
            x2_stop = x2.detach()
            x2_stop.requires_grad = True
            F_x21, F_x22 = Fm.forward(x2_stop)
            x1 = (y1 - F_x22) / F_x21
            x1_stop = x1.detach()
            x1_stop.requires_grad = True
            # compute outputs building a sub-graph
            z1 = x1_stop * F_x21 + F_x22
            y2_ = x2_stop * G_z11 + G_z12
            y1_ = z1
            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(y2_, (z1_stop,) + tuple(Gm.parameters()), y2_grad)
            z1_grad = dd[0] + y1_grad
            GWgrads = dd[1:]
            dd = torch.autograd.grad(y1_, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)
            FWgrads = dd[2:]
            # NOTE(review): raw y2_grad is added here, which assumes
            # dy2/dx2 == 1 as in *additive* coupling; for affine coupling the
            # scale factor G_z11 appears to be missing — verify before relying
            # on this deprecated implementation.
            x2_grad = dd[1] + y2_grad
            x1_grad = dd[0]
            grad_input = torch.cat([x1_grad, x2_grad], dim=1)
            y1_.detach_()
            y2_.detach_()
            del y1_, y2_
        # None entries correspond to the non-tensor args (Fm, Gm) of forward
        return (grad_input, None, None) + FWgrads + GWgrads
class AffineBlockInverseFunction2(torch.autograd.Function):
    @staticmethod
    def forward(cty, yin, Fm, Gm, *weights):
        """Forward inverse pass for the affine block computes:
        {y1, y2} = y
        {s2, t2} = Gm(y1)
        x2 = (y2 - t2) / s2
        {s1, t1} = Fm(x2)
        x1 = (y1 - t1) / s1
        output = {x1, x2}

        Parameters
        ----------
        cty : torch.autograd.function context object
            The backward pass context object
        yin : TorchTensor
            Input tensor. Must have channels (2nd dimension) that can be partitioned in two equal partitions
        Fm : nn.Module
            Module to use for computation, must retain dimensions such that Fm(X)=Y, X.shape == Y.shape
        Gm : nn.Module
            Module to use for computation, must retain dimensions such that Gm(X)=Y, X.shape == Y.shape
        *weights : TorchTensor
            weights for Fm and Gm in that order {Fm_w1, ... Fm_wn, Gm_w1, ... Gm_wn};
            only listed so autograd asks this Function for their gradients

        Note
        ----
        All tensor/autograd variable input arguments and the output are
        TorchTensors for the scope of this function

        """
        # check if possible to partition into two equally sized partitions
        assert yin.shape[1] % 2 == 0  # nosec
        # store Fm and Gm functions in context for the backward pass
        cty.Fm = Fm
        cty.Gm = Gm
        with torch.no_grad():
            # partition in two equally sized set of channels
            y = yin.detach()
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()
            # compute outputs
            y1var = y1
            gmr1, gmr2 = Gm.forward(y1var)
            x2 = (y2 - gmr2) / gmr1
            # free y2's storage as soon as it is no longer needed (memory saving)
            y2.set_()
            del y2
            x2var = x2
            fmr1, fmr2 = Fm.forward(x2var)
            x1 = (y1 - fmr2) / fmr1
            y1.set_()
            del y1
            output = torch.cat([x1, x2], dim=1).detach_()
        # save the input and output variables
        cty.save_for_backward(yin, output)
        return output

    @staticmethod
    def backward(cty, grad_output):  # pragma: no cover
        Fm, Gm = cty.Fm, cty.Gm
        # are all variable objects now
        y, output = cty.saved_tensors
        with set_grad_enabled(False):
            x1, x2 = torch.chunk(output, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()
            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)  # nosec
            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()
        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, y1_stop, GW, FW
        # Also recompute inputs (y1, y2) from outputs (x1, x2)
        with set_grad_enabled(True):
            z1_stop = x2
            z1_stop.requires_grad = True
            F_z11, F_z12 = Fm.forward(z1_stop)
            y1 = x1 * F_z11 + F_z12
            y1_stop = y1.detach()
            y1_stop.requires_grad = True
            G_y11, G_y12 = Gm.forward(y1_stop)
            y2 = x2 * G_y11 + G_y12
            y2_stop = y2.detach()
            y2_stop.requires_grad = True
            # compute outputs building a sub-graph
            z1 = (y2_stop - G_y12) / G_y11
            x1_ = (y1_stop - F_z12) / F_z11
            x2_ = z1
            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(x1_, (z1_stop,) + tuple(Fm.parameters()), x1_grad)
            z1_grad = dd[0] + x2_grad
            FWgrads = dd[1:]
            dd = torch.autograd.grad(x2_, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)
            GWgrads = dd[2:]
            # NOTE(review): raw x1_grad is added here, which assumes
            # dx1/dy1 == 1 as in *additive* coupling; for affine coupling a
            # 1/s1 factor appears to be missing — verify before relying on this
            # deprecated implementation.
            y1_grad = dd[1] + x1_grad
            y2_grad = dd[0]
            grad_input = torch.cat([y1_grad, y2_grad], dim=1)
        # None entries correspond to the non-tensor args (Fm, Gm) of forward
        return (grad_input, None, None) + FWgrads + GWgrads
FWgrads = dd[1:] dd = torch.autograd.grad(x2_, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False) GWgrads = dd[2:] y1_grad = dd[1] + x1_grad y2_grad = dd[0] grad_input = torch.cat([y1_grad, y2_grad], dim=1) return (grad_input, None, None) + FWgrads + GWgrads ================================================ FILE: memcnn/models/resnet.py ================================================ """ResNet/RevNet implementation used for The Reversible Residual Network Implemented in PyTorch instead of TensorFlow. @inproceedings{gomez17revnet, author = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse}, title = {The Reversible Residual Network: Backpropagation without Storing Activations} booktitle = {NIPS}, year = {2017}, } Github: https://github.com/renmengye/revnet-public Author: Sil van de Leemput """ import torch.nn as nn import math from memcnn.models.revop import InvertibleModuleWrapper, create_coupling __all__ = ['ResNet', 'BasicBlock', 'Bottleneck', 'RevBasicBlock', 'RevBottleneck', 'BasicBlockSub', 'BottleneckSub', 'conv3x3', 'batch_norm'] def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def batch_norm(x): """match Tensorflow batch norm settings""" return nn.BatchNorm2d(x, momentum=0.99, eps=0.001) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False): super(BasicBlock, self).__init__() self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.basicblock_sub(x) if self.downsample is not None: residual = self.downsample(x) out += residual return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False): super(Bottleneck, self).__init__() 
self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.bottleneck_sub(x) if self.downsample is not None: residual = self.downsample(x) out += residual return out class RevBasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False): super(RevBasicBlock, self).__init__() if downsample is None and stride == 1: gm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation) fm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation) coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive') self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False) else: self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation) self.downsample = downsample self.stride = stride def forward(self, x): if self.downsample is not None: out = self.basicblock_sub(x) residual = self.downsample(x) out += residual else: out = self.revblock(x) return out class RevBottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False): super(RevBottleneck, self).__init__() if downsample is None and stride == 1: gm = BottleneckSub(inplanes // 2, planes // 2, stride, noactivation) fm = BottleneckSub(inplanes // 2, planes // 2, stride, noactivation) coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive') self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False) else: self.bottleneck_sub = BottleneckSub(inplanes, planes, stride, noactivation) self.downsample = downsample self.stride = stride def forward(self, x): if self.downsample is not None: out = self.bottleneck_sub(x) residual = self.downsample(x) out += residual else: out = self.revblock(x) return out class BottleneckSub(nn.Module): def __init__(self, inplanes, planes, stride=1, noactivation=False): super(BottleneckSub, self).__init__() self.noactivation = 
class BottleneckSub(nn.Module):
    # Residual-branch body of a bottleneck block (pre-activation style):
    # [BN+ReLU] -> 1x1 conv -> BN+ReLU -> 3x3 conv(stride) -> BN+ReLU -> 1x1 conv (4x planes).
    def __init__(self, inplanes, planes, stride=1, noactivation=False):
        super(BottleneckSub, self).__init__()
        # noactivation skips the very first BN+ReLU (used for the first layer,
        # which directly follows the stem's own BN+ReLU)
        self.noactivation = noactivation
        if not self.noactivation:
            self.bn1 = batch_norm(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = batch_norm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = batch_norm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if not self.noactivation:
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv3(x)
        return x


class BasicBlockSub(nn.Module):
    # Residual-branch body of a basic block (pre-activation style):
    # [BN+ReLU] -> 3x3 conv(stride) -> BN+ReLU -> 3x3 conv.
    def __init__(self, inplanes, planes, stride=1, noactivation=False):
        super(BasicBlockSub, self).__init__()
        # noactivation skips the very first BN+ReLU (see BottleneckSub)
        self.noactivation = noactivation
        if not self.noactivation:
            self.bn1 = batch_norm(inplanes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = batch_norm(planes)
        self.conv2 = conv3x3(planes, planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if not self.noactivation:
            x = self.bn1(x)
            x = self.relu(x)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x


class ResNet(nn.Module):
    # ResNet/RevNet trunk: stem conv -> 3 or 4 block layers -> final BN+ReLU ->
    # global average pool -> fully connected classifier.
    def __init__(self, block, layers, num_classes=1000, channels_per_layer=None,
                 strides=None, init_max_pool=False, init_kernel_size=7,
                 batch_norm_fix=True, implementation=0):
        if channels_per_layer is None:
            # default: 64, 128, 256, ... doubled per layer; first entry is
            # duplicated because it also sizes the stem convolution
            channels_per_layer = [2 ** (i + 6) for i in range(len(layers))]
            channels_per_layer = [channels_per_layer[0]] + channels_per_layer
        if strides is None:
            strides = [2] * len(channels_per_layer)
        self.batch_norm_fix = batch_norm_fix
        self.channels_per_layer = channels_per_layer
        self.strides = strides
        self.init_max_pool = init_max_pool
        self.implementation = implementation
        # one channel entry per layer plus one for the stem
        assert(len(self.channels_per_layer) == len(layers) + 1)  # nosec
        self.inplanes = channels_per_layer[0]  # 64 by default
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=init_kernel_size, stride=strides[0],
                               padding=(init_kernel_size - 1) // 2, bias=False)
        self.bn1 = batch_norm(self.inplanes)
        self.relu = nn.ReLU(inplace=False)
        if self.init_max_pool:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # first layer skips its initial BN+ReLU (noactivation) since the stem
        # already applied one
        self.layer1 = self._make_layer(block, channels_per_layer[1], layers[0],
                                       stride=strides[1], noactivation=True)
        self.layer2 = self._make_layer(block, channels_per_layer[2], layers[1], stride=strides[2])
        self.layer3 = self._make_layer(block, channels_per_layer[3], layers[2], stride=strides[3])
        self.has_4_layers = len(layers) >= 4
        if self.has_4_layers:
            self.layer4 = self._make_layer(block, channels_per_layer[4], layers[3], stride=strides[4])
        self.bn_final = batch_norm(self.inplanes)  # channels_per_layer[-1])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels_per_layer[-1] * block.expansion, num_classes)
        self.configure()
        self.init_weights()

    def init_weights(self):
        """Initialization using He initialization"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He init: std = sqrt(2 / fan_out)
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.reset_parameters()

    def configure(self):
        """Initialization specific configuration settings"""
        for m in self.modules():
            if isinstance(m, InvertibleModuleWrapper):
                m.implementation = self.implementation
            elif isinstance(m, nn.BatchNorm2d):
                if self.batch_norm_fix:
                    # TensorFlow-style batch norm settings
                    m.momentum = 0.99
                    m.eps = 0.001
                else:
                    # PyTorch defaults
                    m.momentum = 0.1
                    m.eps = 1e-05

    def _make_layer(self, block, planes, blocks, stride=1, noactivation=False):
        # Build one stage: the first block may downsample (stride/projection),
        # the remaining `blocks - 1` keep the resolution and channel count.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                batch_norm(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, noactivation)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.init_max_pool:
            x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.has_4_layers:
            x = self.layer4(x)
        x = self.bn_final(x)
        x = self.relu(x)
        x = self.avgpool(x)
        # flatten to (batch, features) for the classifier
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class InvertibleCheckpointFunction(torch.autograd.Function):
    # Checkpoint-like autograd Function that frees the *input* storage after the
    # forward pass and reconstructs the inputs on the backward pass via
    # fn_inverse, trading compute for memory.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, fn, fn_inverse, keep_input, num_bwd_passes, preserve_rng_state, num_inputs, *inputs_and_weights):
        # store in context
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.keep_input = keep_input
        # trailing entries of inputs_and_weights are the trainable parameters,
        # listed so autograd requests gradients for them from this Function
        ctx.weights = inputs_and_weights[num_inputs:]
        ctx.num_bwd_passes = num_bwd_passes
        ctx.preserve_rng_state = preserve_rng_state
        ctx.num_inputs = num_inputs
        inputs = inputs_and_weights[:num_inputs]

        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*inputs)

        ctx.input_requires_grad = [element.requires_grad for element in inputs]

        with torch.no_grad():
            # Makes a detached copy which shares the storage
            x = [element.detach() for element in inputs]
            outputs = ctx.fn(*x)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Detaches y in-place (inbetween computations can now be discarded)
        detached_outputs = tuple([element.detach_() for element in outputs])

        # clear memory from inputs
        if not ctx.keep_input:
            # PyTorch 1.0+ way to clear storage (resize to zero elements)
            for element in inputs:
                element.storage().resize_(0)

        # store these tensor nodes for backward pass; one copy of the reference
        # list per allowed backward pass (popped on each backward)
        ctx.inputs = [inputs] * num_bwd_passes
        ctx.outputs = [detached_outputs] * num_bwd_passes

        return detached_outputs

    @staticmethod
    @custom_bwd
    def backward(ctx, *grad_outputs):  # pragma: no cover
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible")
        # retrieve input and output tensor nodes
        if len(ctx.outputs) == 0:
            raise RuntimeError("Trying to perform backward on the InvertibleCheckpointFunction for more than "
                               "{} times! Try raising `num_bwd_passes` by one.".format(ctx.num_bwd_passes))
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        # recompute input if necessary
        if not ctx.keep_input:
            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
                rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
                if ctx.preserve_rng_state:
                    torch.set_rng_state(ctx.fwd_cpu_state)
                    if ctx.had_cuda_in_fwd:
                        set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
                # recompute input: re-grow the freed storage, then fill it with
                # the reconstruction produced by the inverse function
                with torch.no_grad():
                    inputs_inverted = ctx.fn_inverse(*outputs)
                    if not isinstance(inputs_inverted, tuple):
                        inputs_inverted = (inputs_inverted,)
                    for element_original, element_inverted in zip(inputs, inputs_inverted):
                        element_original.storage().resize_(int(np.prod(element_original.size())))
                        element_original.set_(element_inverted)

        # compute gradients by re-running fn with grad enabled
        with torch.set_grad_enabled(True):
            detached_inputs = tuple([element.detach().requires_grad_() for element in inputs])
            temp_output = ctx.fn(*detached_inputs)
        if not isinstance(temp_output, tuple):
            temp_output = (temp_output,)
        gradients = torch.autograd.grad(outputs=temp_output,
                                        inputs=detached_inputs + ctx.weights,
                                        grad_outputs=grad_outputs)

        # Setting the gradients manually on the inputs and outputs (mimic backwards)
        for element, element_grad in zip(inputs, gradients[:ctx.num_inputs]):
            element.grad = element_grad
        for element, element_grad in zip(outputs, grad_outputs):
            element.grad = element_grad

        # None entries correspond to the six non-tensor args of forward
        return (None, None, None, None, None, None) + gradients
class InvertibleModuleWrapper(nn.Module):
    def __init__(self, fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1,
                 disable=False, preserve_rng_state=False):
        """Wrap an invertible module to trade compute for memory during training.

        Parameters
        ----------
        fn : :obj:`torch.nn.Module`
            A torch.nn.Module with a forward and an inverse function satisfying
            :math:`x == m.inverse(m.forward(x))`.
        keep_input : :obj:`bool`, optional
            Retain the input on forward; by default it is discarded and
            reconstructed on the backward pass.
        keep_input_inverse : :obj:`bool`, optional
            Retain the input on inverse; by default it is discarded and
            reconstructed on the backward pass.
        num_bwd_passes : :obj:`int`, optional
            Number of backward passes to retain a link with the output. After the
            last one the output is discarded and its memory freed. Warning: setting
            this higher than the number of required passes prevents memory from
            being freed correctly, so keep it at 1 unless an error tells you otherwise.
        disable : :obj:`bool`, optional
            Bypass the InvertibleCheckpointFunction entirely, computing plain
            :math:`y = fn(x)` with no memory savings; keep_input/keep_input_inverse
            are ignored in that case.
        preserve_rng_state : :obj:`bool`, optional
            Re-apply the forward RNG state when reconstructing discarded inputs.
            Defaults to False since a valid inverse is normally deterministic.

        Attributes
        ----------
        keep_input : :obj:`bool`
            Retain the input on forward (see above).
        keep_input_inverse : :obj:`bool`
            Retain the input on inverse (see above).

        Note
        ----
        Works with mixed-precision training (:obj:`torch.cuda.amp.autocast`,
        torch v1.6+), but inputs are always cast to :obj:`torch.float32`
        internally to keep the recomputed graph connected on the backward pass.
        """
        super(InvertibleModuleWrapper, self).__init__()
        self.disable = disable
        self.keep_input = keep_input
        self.keep_input_inverse = keep_input_inverse
        self.num_bwd_passes = num_bwd_passes
        self.preserve_rng_state = preserve_rng_state
        self._fn = fn

    def forward(self, *xin):
        """Forward operation :math:`R(x) = y`

        Parameters
        ----------
        *xin : :obj:`torch.Tensor` tuple
            Input torch tensor(s).

        Returns
        -------
        :obj:`torch.Tensor` tuple
            Output torch tensor(s) *y.
        """
        if self.disable:
            y = self._fn(*xin)
        else:
            trainable = tuple(p for p in self._fn.parameters() if p.requires_grad)
            y = InvertibleCheckpointFunction.apply(
                self._fn.forward,
                self._fn.inverse,
                self.keep_input,
                self.num_bwd_passes,
                self.preserve_rng_state,
                len(xin),
                *(xin + trainable))
        # single-output modules yield a bare tensor rather than a 1-tuple
        return y[0] if isinstance(y, tuple) and len(y) == 1 else y

    def inverse(self, *yin):
        """Inverse operation :math:`R^{-1}(y) = x`

        Parameters
        ----------
        *yin : :obj:`torch.Tensor` tuple
            Input torch tensor(s).

        Returns
        -------
        :obj:`torch.Tensor` tuple
            Output torch tensor(s) *x.
        """
        if self.disable:
            x = self._fn.inverse(*yin)
        else:
            trainable = tuple(p for p in self._fn.parameters() if p.requires_grad)
            x = InvertibleCheckpointFunction.apply(
                self._fn.inverse,
                self._fn.forward,
                self.keep_input_inverse,
                self.num_bwd_passes,
                self.preserve_rng_state,
                len(yin),
                *(yin + trainable))
        # single-output modules yield a bare tensor rather than a 1-tuple
        return x[0] if isinstance(x, tuple) and len(x) == 1 else x
""" if not self.disable: x = InvertibleCheckpointFunction.apply( self._fn.inverse, self._fn.forward, self.keep_input_inverse, self.num_bwd_passes, self.preserve_rng_state, len(yin), *(yin + tuple([p for p in self._fn.parameters() if p.requires_grad]))) else: x = self._fn.inverse(*yin) # If the layer only has one input, we unpack the tuple again if isinstance(x, tuple) and len(x) == 1: return x[0] return x class ReversibleBlock(InvertibleModuleWrapper): def __init__(self, Fm, Gm=None, coupling='additive', keep_input=False, keep_input_inverse=False, implementation_fwd=-1, implementation_bwd=-1, adapter=None): """The ReversibleBlock Warning ------- This class has been deprecated. Use the more flexible InvertibleModuleWrapper class. Note ---- The `implementation_fwd` and `implementation_bwd` parameters can be set to one of the following implementations: * -1 Naive implementation without reconstruction on the backward pass. * 0 Memory efficient implementation, compute gradients directly. * 1 Memory efficient implementation, similar to approach in Gomez et al. 2017. Parameters ---------- Fm : :obj:`torch.nn.Module` A torch.nn.Module encapsulating an arbitrary function Gm : :obj:`torch.nn.Module`, optional A torch.nn.Module encapsulating an arbitrary function (If not specified a deepcopy of Fm is used as a Module) coupling : :obj:`str`, optional Type of coupling ['additive', 'affine']. Default = 'additive' keep_input : :obj:`bool`, optional Set to retain the input information on forward, by default it can be discarded since it will be reconstructed upon the backward pass. keep_input_inverse : :obj:`bool`, optional Set to retain the input information on inverse, by default it can be discarded since it will be reconstructed upon the backward pass. implementation_fwd : :obj:`int`, optional Switch between different Operation implementations for forward training (Default = 1). If using the naive implementation (-1) then `keep_input` should be True. 
implementation_bwd : :obj:`int`, optional Switch between different Operation implementations for backward training (Default = 1). If using the naive implementation (-1) then `keep_input_inverse` should be True. adapter : :obj:`class`, optional Only relevant when using the 'affine' coupling. Should be a class of type :obj:`torch.nn.Module` that serves as an optional wrapper class A for Fm and Gm which must output s, t = A(x) with shape(s) = shape(t) = shape(x). s, t are respectively the scale and shift tensors for the affine coupling. Attributes ---------- keep_input : :obj:`bool`, optional Set to retain the input information on forward, by default it can be discarded since it will be reconstructed upon the backward pass. keep_input_inverse : :obj:`bool`, optional Set to retain the input information on inverse, by default it can be discarded since it will be reconstructed upon the backward pass. Raises ------ NotImplementedError If an unknown coupling or implementation is given. """ warnings.warn("This class has been deprecated. 
Use the more flexible InvertibleModuleWrapper class", DeprecationWarning)
        fn = create_coupling(Fm=Fm, Gm=Gm, coupling=coupling,
                             implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd,
                             adapter=adapter)
        super(ReversibleBlock, self).__init__(fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse)


def create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):
    # Factory for coupling modules: dispatches on the coupling name and forwards
    # the implementation selectors (and, for 'affine', the adapter) unchanged.
    if coupling == 'additive':
        fn = AdditiveCoupling(Fm, Gm, implementation_fwd=implementation_fwd,
                              implementation_bwd=implementation_bwd)
    elif coupling == 'affine':
        fn = AffineCoupling(Fm, Gm, adapter=adapter, implementation_fwd=implementation_fwd,
                            implementation_bwd=implementation_bwd)
    else:
        # Any other coupling name is rejected explicitly.
        raise NotImplementedError('Unknown coupling method: %s' % coupling)
    return fn


def is_invertible_module(module_in, test_input_shape, test_input_dtype=torch.float32, atol=1e-6, random_seed=42):
    """Test if a :obj:`torch.nn.Module` is invertible

    The check is empirical: seeded random test inputs are pushed through
    ``forward`` then ``inverse`` (and ``inverse`` then ``forward``) and the
    round-trip results are compared against the originals with ``torch.allclose``.

    Parameters
    ----------
    module_in : :obj:`torch.nn.Module`
        A torch.nn.Module to test.
    test_input_shape : :obj:`tuple` of :obj:`int` or :obj:`tuple` of :obj:`tuple` of :obj:`int`
        Dimensions of test tensor(s) object to perform the test with.
    test_input_dtype : :obj:`torch.dtype`, optional
        Data type of test tensor object to perform the test with.
    atol : :obj:`float`, optional
        Tolerance value used for comparing the outputs.
    random_seed : :obj:`int`, optional
        Use this value to seed the pseudo-random test_input_shapes with different numbers.

    Returns
    -------
    :obj:`bool`
        True if the input module is invertible, False otherwise.

    """
    # Unwrap a wrapper so the raw function is probed directly.
    if isinstance(module_in, InvertibleModuleWrapper):
        module_in = module_in._fn
    # A module without an inverse method cannot be invertible by this check.
    if not hasattr(module_in, "inverse"):
        return False

    def _type_check_input_shape(test_input_shape):
        # Accepts either a flat tuple/list of ints, or a tuple/list of
        # tuples/lists of ints (one shape per input tensor).
        if isinstance(test_input_shape, (tuple, list)):
            if all([isinstance(e, int) for e in test_input_shape]):
                return True
            elif all([isinstance(e, (tuple, list)) for e in test_input_shape]):
                return all([isinstance(ee, int) for e in test_input_shape for ee in e])
            else:
                return False
        else:
            return False

    if not _type_check_input_shape(test_input_shape):
        raise ValueError("test_input_shape should be of type Tuple[int, ...] or "
                         "Tuple[Tuple[int, ...], ...], but {} found".format(type(test_input_shape)))

    # Normalize a single shape to the multi-input form: a tuple of shapes.
    if not isinstance(test_input_shape[0], (tuple, list)):
        test_input_shape = (test_input_shape,)

    def _check_inputs_allclose(inputs, reference, atol):
        # Elementwise-compare each produced tensor against its reference.
        for inp, ref in zip(inputs, reference):
            if not torch.allclose(inp, ref, atol=atol):
                return False
        return True

    def _pack_if_no_tuple(x):
        # Modules may return a single tensor or a tuple; make both iterable.
        if not isinstance(x, tuple):
            return (x, )
        return x

    with torch.no_grad():
        torch.manual_seed(random_seed)
        test_inputs = tuple([torch.rand(shape, dtype=test_input_dtype) for shape in test_input_shape])
        if any([torch.equal(torch.zeros_like(e), e) for e in test_inputs]):  # pragma: no cover
            warnings.warn("Some inputs were detected to be all zeros, you might want to set a different random_seed.")
        # Round-trip 1: inverse(forward(x)) must reproduce x within atol.
        if not _check_inputs_allclose(_pack_if_no_tuple(module_in.inverse(*_pack_if_no_tuple(module_in(*test_inputs)))),
                                      test_inputs, atol=atol):
            return False
        test_outputs = _pack_if_no_tuple(module_in(*test_inputs))
        if any([torch.equal(torch.zeros_like(e), e) for e in test_outputs]):  # pragma: no cover
            warnings.warn("Some outputs were detected to be all zeros, you might want to set a different random_seed.")
        # Round-trip 2: forward(inverse(y)) must reproduce y within atol.
        if not _check_inputs_allclose(_pack_if_no_tuple(module_in(*_pack_if_no_tuple(module_in.inverse(*test_outputs)))),
                                      test_outputs, atol=atol):  # pragma: no cover
            return False
        test_reconstructed_inputs = _pack_if_no_tuple(module_in.inverse(*test_outputs))

    def _test_shared(inputs, outputs, msg):
        # Warn (do not fail) when tensors are aliased between/within the
        # input and output tuples; sharing defeats the memory savings of
        # InvertibleModuleWrapper.
        shared = set(inputs)
        shared_outputs = set(outputs)
        if len(inputs) != len(shared):  # pragma: no cover
            warnings.warn("Some inputs (*x) share the same tensor, are you sure this is what you want? ({})".format(msg))
        if len(outputs) != len(shared_outputs):
            warnings.warn("Some outputs (*y) share the same tensor, are you sure this is what you want? ({})".format(msg))
        if any([inp in shared for inp in shared_outputs]):
            warnings.warn("Some inputs (*x) and outputs (*y) share the same tensor, this is typically not a "
                          "good function to use with memcnn.InvertibleModuleWrapper as it might increase memory usage. "
                          "E.g. an identity function. ({})".format(msg))

    _test_shared(test_inputs, test_outputs, msg="forward")
    _test_shared(test_reconstructed_inputs, test_outputs, msg="inverse")
    return True


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
#
# get_device_states and set_device_states cannot be imported from torch.utils.checkpoint, since it was not
# present in older versions, so we include a copy here.
def get_device_states(*args):
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states):
    # Restore the per-device CUDA RNG states previously captured by
    # get_device_states; devices and states are matched pairwise.
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)



================================================
FILE: memcnn/models/tests/__init__.py
================================================



================================================
FILE: memcnn/models/tests/test_amp.py
================================================
import pytest
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torch.utils.checkpoint import checkpoint
from torchvision.models.resnet import resnet18, BasicBlock
import torchvision.transforms as transforms

import memcnn

# torch.cuda.amp only exists for torch >= 1.6; tests are skipped when absent.
try:
    from torch.cuda.amp import autocast, GradScaler
except ModuleNotFoundError:
    pass


class InvertibleBlock(nn.Module):
    # Wraps a residual block in an AdditiveCoupling inside an
    # InvertibleModuleWrapper; `enabled` toggles the wrapper's disable flag.
    def __init__(self, block, keep_input, enabled=True):
        super().__init__()
        self.invertible_block = memcnn.InvertibleModuleWrapper(
            fn=memcnn.AdditiveCoupling(block),
            keep_input=keep_input,
            keep_input_inverse=keep_input,
            disable=not enabled,
        )

    def forward(self, x, inverse=False):
        if inverse:
            return self.invertible_block.inverse(x)
        else:
            return self.invertible_block(x)


class CheckPointBlock(nn.Module):
    # Reference block: same coupling, but memory saving comes from
    # torch.utils.checkpoint instead of memcnn's wrapper.
    def __init__(self, block):
        super().__init__()
        self.invertible_module = memcnn.AdditiveCoupling(block)

    def forward(self, x, inverse=False):
        return checkpoint(self.invertible_module.forward, x)


@pytest.mark.skipif(
    condition="autocast" not in locals(),
    reason="torch.cuda.amp could not be found. torch version is < 1.6.",
)
@pytest.mark.parametrize(
    "use_checkpointing, inv_enabled", ((True, False), (False, True,), (False, False))
)
@pytest.mark.parametrize("amp_enabled", (False, True))
def test_cuda_amp(tmp_path, inv_enabled, amp_enabled, use_checkpointing):
    # Trains a single CIFAR-10 batch through a resnet18 whose layer1 is
    # replaced by invertible or checkpointed blocks, under AMP autocast +
    # GradScaler; verifies the combination runs end-to-end.
    if not torch.cuda.is_available() and amp_enabled:
        pytest.skip("This test requires a GPU to be available")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = resnet18(num_classes=10)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=tmp_path, train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=4, shuffle=True, num_workers=2
    )

    # Replace layer1
    if not use_checkpointing:
        model.layer1 = nn.Sequential(
            InvertibleBlock(BasicBlock(32, 32), keep_input=False, enabled=inv_enabled),
            InvertibleBlock(BasicBlock(32, 32), keep_input=False, enabled=inv_enabled),
        )
    else:
        model.layer1 = nn.Sequential(
            CheckPointBlock(BasicBlock(32, 32)),
            CheckPointBlock(BasicBlock(32, 32))
        )

    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scaler = GradScaler(enabled=amp_enabled)

    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast(enabled=amp_enabled):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Only one batch is needed; this is a smoke test, not a training run.
        break


================================================
FILE: memcnn/models/tests/test_couplings.py
================================================
import torch
import torch.nn
import pytest
import copy
import warnings
from memcnn import create_coupling, InvertibleModuleWrapper
from memcnn.models.tests.test_revop import set_seeds
from memcnn.models.tests.test_models import SubModule
from 
memcnn.models.affine import AffineAdapterNaive, AffineBlock
from memcnn.models.additive import AdditiveBlock


@pytest.mark.parametrize('coupling', ['additive', 'affine'])
@pytest.mark.parametrize('bwd', [False, True])
@pytest.mark.parametrize('implementation', [-1, 0, 1])
def test_coupling_implementations_against_reference(coupling, bwd, implementation):
    """Test if similar gradients and weights results are obtained after similar training for the couplings"""
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=DeprecationWarning)
        for seed in range(10):
            set_seeds(seed)
            X = torch.rand(2, 4, 5, 5)

            # define models and their copies
            c1 = torch.nn.Conv2d(2, 2, 3, padding=1)
            c2 = torch.nn.Conv2d(2, 2, 3, padding=1)
            c1_2 = copy.deepcopy(c1)
            c2_2 = copy.deepcopy(c2)
            # are weights between models the same, but do they differ between convolutions?
            assert torch.equal(c1.weight, c1_2.weight)
            assert torch.equal(c2.weight, c2_2.weight)
            assert torch.equal(c1.bias, c1_2.bias)
            assert torch.equal(c2.bias, c2_2.bias)
            assert not torch.equal(c1.weight, c2.weight)

            # define optimizers
            optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)
            optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)
            for e in [c1, c2, c1_2, c2_2]:
                e.train()

            # define an arbitrary reversible function and define graph for model 1
            # (reference implementation: -1)
            XX = X.detach().clone().requires_grad_()
            coupling_fn = create_coupling(Fm=c1, Gm=c2, coupling=coupling, implementation_fwd=-1,
                                          implementation_bwd=-1, adapter=AffineAdapterNaive)
            Y = coupling_fn.inverse(XX) if bwd else coupling_fn.forward(XX)
            loss = torch.mean(Y)

            # define the reversible function without custom backprop and define graph for model 2
            # (implementation under test)
            XX2 = X.detach().clone().requires_grad_()
            coupling_fn2 = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling,
                                           implementation_fwd=implementation,
                                           implementation_bwd=implementation, adapter=AffineAdapterNaive)
            Y2 = coupling_fn2.inverse(XX2) if bwd else coupling_fn2.forward(XX2)
            loss2 = torch.mean(Y2)

            # compute gradients manually
            grads = torch.autograd.grad(loss2, (XX2, c1_2.weight, c2_2.weight, c1_2.bias, c2_2.bias),
                                        None, retain_graph=True)

            # compute gradients using backward and perform optimization model 2
            loss2.backward()
            optim2.step()

            # gradients computed manually match those of the .backward() pass
            assert torch.equal(c1_2.weight.grad, grads[1])
            assert torch.equal(c2_2.weight.grad, grads[2])
            assert torch.equal(c1_2.bias.grad, grads[3])
            assert torch.equal(c2_2.bias.grad, grads[4])

            # weights differ after training a single model?
            assert not torch.equal(c1.weight, c1_2.weight)
            assert not torch.equal(c2.weight, c2_2.weight)
            assert not torch.equal(c1.bias, c1_2.bias)
            assert not torch.equal(c2.bias, c2_2.bias)

            # compute gradients and perform optimization model 1
            loss.backward()
            optim1.step()

            # weights are approximately the same after training both models?
            assert torch.allclose(c1.weight.detach(), c1_2.weight.detach())
            assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())
            assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())
            assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())

            # gradients are approximately the same after training both models?
            assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())
            assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())
            assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())
            assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())

            # memory-saving wrappers still free the inputs after backward
            fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
            Yout = fn.inverse(XX) if bwd else fn.forward(XX)
            loss = torch.mean(Yout)
            loss.backward()
            assert XX.storage().size() > 0

            fn2 = InvertibleModuleWrapper(fn=coupling_fn2, keep_input=False, keep_input_inverse=False)
            Yout2 = fn2.inverse(XX2) if bwd else fn2.forward(XX2)
            loss = torch.mean(Yout2)
            loss.backward()
            assert XX2.storage().size() > 0


def test_legacy_additive_coupling():
    # Constructing the deprecated AdditiveBlock must still work (warning ignored).
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=DeprecationWarning)
        AdditiveBlock(Fm=SubModule())


def test_legacy_affine_coupling():
    # Constructing the deprecated AffineBlock must still work (warning ignored).
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=DeprecationWarning)
        AffineBlock(Fm=SubModule())


================================================
FILE: memcnn/models/tests/test_is_invertible_module.py
================================================
import pytest
import torch

from memcnn import is_invertible_module, InvertibleModuleWrapper, AdditiveCoupling
from memcnn.models.tests.test_models import IdentityInverse, MultiSharedOutputs, SubModule


def test_is_invertible_module_with_invalid_inverse():
    # Zeroing the factor makes inverse() return zeros, breaking invertibility.
    fn = IdentityInverse(multiply_inverse=True)
    with torch.no_grad():
        fn.factor.zero_()
    assert not is_invertible_module(fn, test_input_shape=(12, 12))


@pytest.mark.parametrize("random_seed", [1, 42, 900000])
def test_is_invertible_module_random_seeds(random_seed):
    # The verdict should not depend on the chosen random seed.
    fn = IdentityInverse(multiply_forward=True, multiply_inverse=True)
    assert is_invertible_module(fn, test_input_shape=(1, ), random_seed=random_seed)


def test_is_invertible_module_shared_outputs():
    # A module returning the same tensor twice triggers the shared-output warning.
    fnb = MultiSharedOutputs()
    X = torch.rand(1, 2, 5, 5, 
dtype=torch.float32).requires_grad_()
    with pytest.warns(UserWarning):
        assert is_invertible_module(fnb, test_input_shape=(X.shape,), atol=1e-6)


def test_is_invertible_module_shared_tensors():
    # A pure identity shares input and output tensors; this should warn until
    # the forward/inverse multiplications make the outputs distinct tensors.
    fn = IdentityInverse()
    rm = InvertibleModuleWrapper(fn=fn, keep_input=True, keep_input_inverse=True)
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    with pytest.warns(UserWarning):
        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    with pytest.warns(UserWarning):
        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)


def test_is_invertible_module():
    X = torch.zeros(1, 10, 10, 10)
    # A plain Conv2d has no inverse method, so it is rejected.
    assert not is_invertible_module(torch.nn.Conv2d(10, 10, kernel_size=(1, 1)),
                                    test_input_shape=X.shape)
    fn = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(fn, test_input_shape=X.shape)

    # Has an inverse method, but it is not the mathematical inverse of forward.
    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)


def test_is_invertible_module_wrapped():
    # Same checks as above, but through an InvertibleModuleWrapper (which
    # is_invertible_module unwraps internally).
    X = torch.zeros(1, 10, 10, 10)
    assert not is_invertible_module(InvertibleModuleWrapper(torch.nn.Conv2d(10, 10, kernel_size=(1, 1))),
                                    test_input_shape=X.shape)
    fn = InvertibleModuleWrapper(AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1))
    assert is_invertible_module(fn, test_input_shape=X.shape)

    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    assert not is_invertible_module(InvertibleModuleWrapper(FakeInverse()), test_input_shape=X.shape)


@pytest.mark.parametrize("input_shape", (
    "string", (2.3, 1.4), None, True, ((1, 3, ), (12.4)), ((1, 3, ), False)
))
def test_is_invertible_module_type_check_input_shapes(input_shape):
    # Malformed shape specifications must raise ValueError.
    with pytest.raises(ValueError):
        is_invertible_module(module_in=IdentityInverse(multiply_forward=True, multiply_inverse=True),
                             test_input_shape=input_shape)


================================================
FILE: memcnn/models/tests/test_memory_saving.py
================================================
import pytest
import gc
import numpy as np
import torch
import torch.nn

from memcnn.models.tests.test_models import SubModule, SubModuleStack


@pytest.mark.parametrize('coupling', ['additive', 'affine'])
@pytest.mark.parametrize('keep_input', [True, False])
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_memory_saving_invertible_model_wrapper(device, coupling, keep_input):
    """Test memory saving of the invertible model wrapper

    * tests fitting a large number of images by creating a deep network
      requiring large intermediate feature maps for training

    * keep_input = False should use less memory than keep_input = True
      on both GPU and CPU RAM

    * input size in bytes:
        np.prod((2, 10, 10, 10)) * 4 / 1024.0 = 7.8125 kB
      for a depth=5 this yields 7.8125 * 5 = 39.0625 kB

    """
    if device == 'cpu':
        pytest.skip('Unreliable metrics, should be fixed.')
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('This test requires a GPU to be available')
    # GC is disabled so memory measurements are not perturbed mid-run.
    gc.disable()
    gc.collect()
    with torch.set_grad_enabled(True):
        dims = [2, 10, 10, 10]
        depth = 5
        xx = torch.rand(*dims, device=device, dtype=torch.float32).requires_grad_()
        ytarget = torch.rand(*dims, device=device, dtype=torch.float32)

        # same convolution test
        network = SubModuleStack(SubModule(in_filters=5, out_filters=5), depth=depth, keep_input=keep_input,
                                 coupling=coupling, implementation_fwd=-1, implementation_bwd=-1)
        network.to(device)
        network.train()
        network.zero_grad()
        optim = torch.optim.RMSprop(network.parameters())
        optim.zero_grad()
        mem_start = 0 if not device == 'cuda' else \
            torch.cuda.memory_allocated() / float(1024 ** 2)

        y = network(xx)
        gc.collect()
        mem_after_forward = torch.cuda.memory_allocated() / float(1024 ** 
2)
        loss = torch.nn.MSELoss()(y, ytarget)
        optim.zero_grad()
        loss.backward()
        optim.step()
        gc.collect()
        # mem_after_backward = torch.cuda.memory_allocated() / float(1024 ** 2)
    gc.enable()
    # Expected lower bound when inputs are kept: depth feature maps of 4-byte floats.
    memuse = float(np.prod(dims + [depth, 4, ])) / float(1024 ** 2)
    measured_memuse = mem_after_forward - mem_start
    if keep_input:
        assert measured_memuse >= memuse
    else:
        assert measured_memuse < 1
    # assert math.floor(mem_after_backward - mem_start) >= 9


================================================
FILE: memcnn/models/tests/test_models.py
================================================
import torch
import torch.nn

from memcnn import create_coupling, InvertibleModuleWrapper


class MultiplicationInverse(torch.nn.Module):
    # Invertible toy module: forward multiplies by a learnable factor,
    # inverse divides by it.
    def __init__(self, factor=2):
        super(MultiplicationInverse, self).__init__()
        self.factor = torch.nn.Parameter(torch.ones(1) * factor)

    def forward(self, x):
        return x * self.factor

    def inverse(self, y):
        return y / self.factor


class IdentityInverse(torch.nn.Module):
    # Identity by default; the multiply_* flags switch forward/inverse to
    # multiplication by a factor initialized to 1 (still the identity in value,
    # but producing a new output tensor instead of aliasing the input).
    def __init__(self, multiply_forward=False, multiply_inverse=False):
        super(IdentityInverse, self).__init__()
        self.factor = torch.nn.Parameter(torch.ones(1))
        self.multiply_forward = multiply_forward
        self.multiply_inverse = multiply_inverse

    def forward(self, x):
        if self.multiply_forward:
            return x * self.factor
        else:
            return x

    def inverse(self, y):
        if self.multiply_inverse:
            return y * self.factor
        else:
            return y


class MultiSharedOutputs(torch.nn.Module):
    # Returns the same tensor twice from forward, used to exercise the
    # shared-output warnings in is_invertible_module.
    # pylint: disable=R0201
    def forward(self, x):
        y = x * x
        return y, y

    # pylint: disable=R0201
    def inverse(self, y, y2):
        x = torch.max(torch.sqrt(y), torch.sqrt(y2))
        return x


class SubModule(torch.nn.Module):
    # Small conv + batch-norm block used as Fm/Gm in the coupling tests.
    def __init__(self, in_filters=5, out_filters=5):
        super(SubModule, self).__init__()
        self.bn = torch.nn.BatchNorm2d(out_filters)
        self.conv = torch.nn.Conv2d(in_filters, out_filters, (3, 3), padding=1)

    def forward(self, x):
        return self.bn(self.conv(x))


class SubModuleStack(torch.nn.Module):
    # A stack of `depth` InvertibleModuleWrappers all sharing the same coupling.
    def __init__(self, Gm, coupling='additive', depth=10, implementation_fwd=-1,
                 implementation_bwd=-1, keep_input=False, adapter=None, num_bwd_passes=1):
        super(SubModuleStack, self).__init__()
        fn = create_coupling(Fm=Gm, Gm=Gm, coupling=coupling, implementation_fwd=implementation_fwd,
                             implementation_bwd=implementation_bwd, adapter=adapter)
        self.stack = torch.nn.ModuleList(
            [InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input,
                                     num_bwd_passes=num_bwd_passes)
             for _ in range(depth)]
        )

    def forward(self, x):
        for rev_module in self.stack:
            x = rev_module.forward(x)
        return x

    def inverse(self, y):
        # Inversion applies the wrapped inverses in reverse stack order.
        for rev_module in reversed(self.stack):
            y = rev_module.inverse(y)
        return y


class SplitChannels(torch.nn.Module):
    # forward splits the channel dimension at split_location; inverse concatenates.
    def __init__(self, split_location):
        self.split_location = split_location
        super(SplitChannels, self).__init__()

    def forward(self, x):
        return (x[:, :self.split_location, :].clone(),
                x[:, self.split_location:, :].clone())

    # pylint: disable=R0201
    def inverse(self, x, y):
        return torch.cat([x, y], dim=1)


class ConcatenateChannels(torch.nn.Module):
    # Mirror of SplitChannels: forward concatenates, inverse splits.
    def __init__(self, split_location):
        self.split_location = split_location
        super(ConcatenateChannels, self).__init__()

    # pylint: disable=R0201
    def forward(self, x, y):
        return torch.cat([x, y], dim=1)

    def inverse(self, x):
        return (x[:, :self.split_location, :].clone(),
                x[:, self.split_location:, :].clone())


================================================
FILE: memcnn/models/tests/test_multi.py
================================================
import pytest
import torch

from memcnn.models.revop import InvertibleModuleWrapper, is_invertible_module
from memcnn.models.tests.test_models import SplitChannels, ConcatenateChannels


@pytest.mark.parametrize('disable', [True, False])
def test_multi(disable):
    # Multi-input/multi-output invertible modules: split then re-concatenate
    # channels, with convolutions in between, and backprop through the chain.
    split = InvertibleModuleWrapper(SplitChannels(2), disable = disable)
    concat = InvertibleModuleWrapper(ConcatenateChannels(2), disable = disable)

    assert is_invertible_module(split, test_input_shape=(1, 3, 32, 32))
    assert is_invertible_module(concat, test_input_shape=((1, 2, 32, 32), (1, 1, 
32, 32)))

    conv_a = torch.nn.Conv2d(2, 2, 3)
    conv_b = torch.nn.Conv2d(1, 1, 3)

    x = torch.rand(1, 3, 32, 32)
    x.requires_grad = True

    a, b = split(x)
    a, b = conv_a(a), conv_b(b)
    y = concat(a, b)

    loss = torch.sum(y)
    loss.backward()


================================================
FILE: memcnn/models/tests/test_resnet.py
================================================
import pytest
import torch

from memcnn.models.resnet import ResNet, BasicBlock, Bottleneck, RevBasicBlock, RevBottleneck


@pytest.mark.parametrize('block,batch_norm_fix', [(BasicBlock, True), (Bottleneck, False),
                                                  (RevBasicBlock, False), (RevBottleneck, True)])
def test_resnet(block, batch_norm_fix):
    # Smoke test: each (reversible) block type builds and runs a forward pass.
    model = ResNet(block, [2, 2, 2, 2], num_classes=2, channels_per_layer=None,
                   init_max_pool=True, batch_norm_fix=batch_norm_fix, strides=None)
    model.eval()
    with torch.no_grad():
        x = torch.ones(2, 3, 32, 32)
        model.forward(x)


================================================
FILE: memcnn/models/tests/test_revop.py
================================================
import warnings
import pytest
import random
import torch
import torch.nn
import numpy as np
import copy

from memcnn.models.affine import AffineAdapterNaive, AffineAdapterSigmoid, AffineCoupling
from memcnn.models.revop import InvertibleModuleWrapper, ReversibleBlock, create_coupling, \
    is_invertible_module, get_device_states, set_device_states
from memcnn.models.additive import AdditiveCoupling
from memcnn.models.tests.test_models import MultiplicationInverse, SubModule, SubModuleStack


def set_seeds(seed):
    # Seed every RNG the tests touch (python, numpy, torch CPU and CUDA)
    # and force deterministic cudnn behavior.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def is_memory_cleared(var, isclear, shape):
    # A "cleared" tensor has had its storage resized to 0 by the wrapper;
    # otherwise the storage must be non-empty and the shape untouched.
    if isclear:
        return var.storage().size() == 0
    else:
        return var.storage().size() > 0 and var.shape == shape


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('enabled', [True, False])
def test_get_set_device_states(device, enabled):
    # Restoring stashed RNG states must reproduce the same random draw;
    # without restoration the draws must differ.
    shape = (1, 1, 10, 10)
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('This test requires a GPU to be available')
    X = torch.ones(shape, device=device)
    devices, states = get_device_states(X)
    # CPU tensors yield no CUDA devices/states.
    assert len(states) == (1 if device == 'cuda' else 0)
    assert len(devices) == (1 if device == 'cuda' else 0)
    cpu_rng_state = torch.get_rng_state()
    Y = X * torch.rand(shape, device=device)
    with torch.random.fork_rng(devices=devices, enabled=True):
        if enabled:
            if device == 'cpu':
                torch.set_rng_state(cpu_rng_state)
            else:
                set_device_states(devices=devices, states=states)
        Y2 = X * torch.rand(shape, device=device)
    assert torch.equal(Y, Y2) == enabled


@pytest.mark.parametrize('coupling', ['additive', 'affine'])
def test_reversible_block_notimplemented(coupling):
    # Invalid implementation selectors (-2) and unknown coupling names must
    # raise NotImplementedError (deprecation warnings are silenced).
    fm = torch.nn.Conv2d(10, 10, (3, 3), padding=1)
    X = torch.zeros(1, 20, 10, 10)
    with pytest.raises(NotImplementedError):
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=DeprecationWarning)
            f = ReversibleBlock(fm, coupling=coupling, implementation_bwd=0,
                                implementation_fwd=-2, adapter=AffineAdapterNaive)
            assert isinstance(f, InvertibleModuleWrapper)
            f.forward(X)
    with pytest.raises(NotImplementedError):
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=DeprecationWarning)
            f = ReversibleBlock(fm, coupling=coupling, implementation_bwd=-2,
                                implementation_fwd=0, adapter=AffineAdapterNaive)
            assert isinstance(f, InvertibleModuleWrapper)
            f.inverse(X)
    with pytest.raises(NotImplementedError):
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=DeprecationWarning)
            ReversibleBlock(fm, coupling='unknown', implementation_bwd=-2,
                            implementation_fwd=0, adapter=AffineAdapterNaive)


@pytest.mark.parametrize('fn', [
    AdditiveCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),
    AffineCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive),
    AffineCoupling(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1,
                   adapter=AffineAdapterSigmoid),
    MultiplicationInverse()
])
@pytest.mark.parametrize('bwd', [False, True])
@pytest.mark.parametrize('keep_input', [False, True])
@pytest.mark.parametrize('keep_input_inverse', [False, True])
@pytest.mark.parametrize('preserve_rng_state', [False, True])
def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse, preserve_rng_state):
    """InvertibleModuleWrapper tests for the memory saving forward and backward passes

    * test inversion Y = RB(X) and X = RB.inverse(Y)
    * test training the block for a single step and compare weights for implementations: 0, 1
    * test automatic discard of input X and its retrieval after the backward pass
    * test usage of BN to identify non-contiguous memory blocks

    """
    for seed in range(10):
        set_seeds(seed)
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        target_data = torch.rand(*dims, dtype=torch.float32)

        assert is_invertible_module(fn, test_input_shape=data.shape, atol=1e-4)

        # test with zero padded convolution
        with torch.set_grad_enabled(True):
            X = data.clone().requires_grad_()
            Ytarget = target_data.clone()
            Xshape = X.shape
            rb = InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse,
                                         preserve_rng_state=preserve_rng_state)
            s_grad = [p.detach().clone() for p in rb.parameters()]
            rb.train()
            rb.zero_grad()

            optim = torch.optim.RMSprop(rb.parameters())
            optim.zero_grad()
            if not bwd:
                Xin = X.clone().requires_grad_()
                Y = rb(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb.inverse(Yrev)
            else:
                Xin = X.clone().requires_grad_()
                Y = rb.inverse(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb(Yrev)
            loss = torch.nn.MSELoss()(Y, Ytarget)

            # has input been retained/discarded after forward (and backward) passes?
            if not bwd:
                assert is_memory_cleared(Yrev, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Xin, not keep_input, Xshape)
            else:
                assert is_memory_cleared(Xin, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Yrev, not keep_input, Xshape)

            optim.zero_grad()
            loss.backward()
            optim.step()

            assert Y.shape == Xshape
            assert X.detach().shape == data.shape
            assert torch.allclose(X.detach(), data, atol=1e-06)
            assert torch.allclose(X.detach(), Xinv.detach(), atol=1e-04)
            # Model is now trained and will differ
            grads = [p.detach().clone() for p in rb.parameters()]

            assert not torch.allclose(grads[0], s_grad[0])


@pytest.mark.parametrize('coupling,adapter', [('additive', None),
                                              ('affine', AffineAdapterNaive),
                                              ('affine', AffineAdapterSigmoid)])
def test_chained_invertible_module_wrapper(coupling, adapter):
    # Trains a depth-2 stack of wrappers end-to-end for one step and checks
    # the loss stays finite.
    set_seeds(42)
    dims = (2, 10, 8, 8)
    data = torch.rand(*dims, dtype=torch.float32)
    target_data = torch.rand(*dims, dtype=torch.float32)
    with torch.set_grad_enabled(True):
        X = data.clone().requires_grad_()
        Ytarget = target_data.clone()
        # Sigmoid adapter requires Fm to output twice the channels.
        Gm = SubModule(in_filters=5,
                       out_filters=5 if coupling == 'additive' or adapter is AffineAdapterNaive else 10)
        rb = SubModuleStack(Gm, coupling=coupling, depth=2, keep_input=False, adapter=adapter,
                            implementation_bwd=-1, implementation_fwd=-1)
        rb.train()
        optim = torch.optim.RMSprop(rb.parameters())
        rb.zero_grad()
        optim.zero_grad()
        Xin = X.clone()
        Y = rb(Xin)

        loss = torch.nn.MSELoss()(Y, Ytarget)
        loss.backward()
        optim.step()
    assert not torch.isnan(loss)


def test_chained_invertible_module_wrapper_shared_fwd_and_bwd_train_passes():
    # Runs alternating forward-trained and inverse-trained passes from an
    # identical initial state; with num_bwd_passes=2 the graph supports two
    # backward calls, and repeated runs must reproduce identical outputs.
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)
    rb_temp = SubModuleStack(Gm=Gm, coupling='additive', depth=5, keep_input=True, adapter=None,
                             implementation_bwd=-1, implementation_fwd=-1)
    optim = torch.optim.SGD(rb_temp.parameters(), lr=0.01)
    initial_params = [p.detach().clone() for p in rb_temp.parameters()]
    initial_state = copy.deepcopy(rb_temp.state_dict())
    initial_optim_state = copy.deepcopy(optim.state_dict())
    dims = (2, 10, 8, 8)

    data = torch.rand(*dims, dtype=torch.float32)
    target_data = torch.rand(*dims, dtype=torch.float32)
    forward_outputs = []
    inverse_outputs = []
    for i in range(10):
        is_forward_pass = i % 2 == 0
        set_seeds(42)
        # Rebuild the stack from the stored initial state every iteration.
        rb = SubModuleStack(Gm=Gm, coupling='additive', depth=5, keep_input=True, adapter=None,
                            implementation_bwd=-1, implementation_fwd=-1, num_bwd_passes=2)
        rb.train()
        with torch.no_grad():
            for (name, p), p_initial in zip(rb.named_parameters(), initial_params):
                p.set_(p_initial)
        rb.load_state_dict(initial_state)
        optim = torch.optim.SGD(rb_temp.parameters(), lr=0.01)
        optim.load_state_dict(initial_optim_state)

        with torch.set_grad_enabled(True):
            X = data.detach().clone().requires_grad_()
            Ytarget = target_data.detach().clone()
            optim.zero_grad()
            if is_forward_pass:
                Y = rb(X)
                Xinv = rb.inverse(Y)
                Xinv2 = rb.inverse(Y)
                Xinv3 = rb.inverse(Y)
            else:
                Y = rb.inverse(X)
                Xinv = rb(Y)
                Xinv2 = rb(Y)
                Xinv3 = rb(Y)
            # All reconstructions must match the original input.
            for item in [Xinv, Xinv2, Xinv3]:
                assert torch.allclose(X, item, atol=1e-04)
            loss = torch.nn.MSELoss()(Xinv, Ytarget)
            assert not torch.isnan(loss)
            assert Xinv2.grad is None
            assert Xinv3.grad is None
            loss.backward()
            assert Y.grad is not None
            assert Xinv.grad is not None
            assert Xinv2.grad is None
            assert Xinv3.grad is None
            # Second backward through the retained graph (num_bwd_passes=2).
            loss2 = torch.nn.MSELoss()(Xinv2, Ytarget)
            assert not torch.isnan(loss2)
            loss2.backward()
            assert Xinv2.grad is not None
            optim.step()
        if is_forward_pass:
            forward_outputs.append(Y.detach().clone())
        else:
            inverse_outputs.append(Y.detach().clone())

    # Every repetition must produce the exact same outputs as the last one.
    for i in range(4):
        assert torch.allclose(forward_outputs[-1], forward_outputs[i], atol=1e-06)
        assert torch.allclose(inverse_outputs[-1], inverse_outputs[i], atol=1e-06)


@pytest.mark.parametrize("inverted", [False, True])
def test_invertible_module_wrapper_disabled_versus_enabled(inverted):
    # A disabled wrapper must produce the same outputs as an enabled one,
    # but must not clear its input storage.
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)
    coupling_fn = create_coupling(Fm=Gm, Gm=Gm, coupling='additive',
                                  implementation_fwd=-1, implementation_bwd=-1)
    rb = 
InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
    rb2 = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn), keep_input=False, keep_input_inverse=False)
    rb.eval()
    rb2.eval()
    rb2.disable = True
    with torch.no_grad():
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        X, X2 = data.clone().detach().requires_grad_(), data.clone().detach().requires_grad_()
        if not inverted:
            Y = rb(X)
            Y2 = rb2(X2)
        else:
            Y = rb.inverse(X)
            Y2 = rb2.inverse(X2)
        assert torch.allclose(Y, Y2)
        # Enabled wrapper freed its input; disabled wrapper kept it intact.
        assert is_memory_cleared(X, True, dims)
        assert is_memory_cleared(X2, False, dims)


@pytest.mark.parametrize('coupling', ['additive', 'affine'])
def test_invertible_module_wrapper_simple_inverse(coupling):
    """InvertibleModuleWrapper inverse test"""
    for seed in range(10):
        set_seeds(seed)
        # define some data
        X = torch.rand(2, 4, 5, 5).requires_grad_()

        # define an arbitrary reversible function
        coupling_fn = create_coupling(Fm=torch.nn.Conv2d(2, 2, 3, padding=1), coupling=coupling,
                                      implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)

        # compute output
        Y = fn.forward(X.clone())

        # compute input from output
        X2 = fn.inverse(Y)

        # check that the inverted output and the original input are approximately similar
        assert torch.allclose(X2.detach(), X.detach(), atol=1e-06)


@pytest.mark.parametrize('coupling', ['additive', 'affine'])
def test_normal_vs_invertible_module_wrapper(coupling):
    """InvertibleModuleWrapper test if similar gradients and weights results are obtained after similar training"""
    for seed in range(10):
        set_seeds(seed)
        X = torch.rand(2, 4, 5, 5)

        # define models and their copies
        c1 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c2 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c1_2 = copy.deepcopy(c1)
        c2_2 = copy.deepcopy(c2)
        # are weights between models the same, but do they differ between convolutions?
        assert torch.equal(c1.weight, c1_2.weight)
        assert torch.equal(c2.weight, c2_2.weight)
        assert torch.equal(c1.bias, c1_2.bias)
        assert torch.equal(c2.bias, c2_2.bias)
        assert not torch.equal(c1.weight, c2.weight)

        # define optimizers
        optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)
        optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)
        for e in [c1, c2, c1_2, c2_2]:
            e.train()

        # define an arbitrary reversible function and define graph for model 1
        Xin = X.clone().requires_grad_()
        coupling_fn = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling, implementation_fwd=-1,
                                      implementation_bwd=-1, adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
        Y = fn.forward(Xin)
        loss2 = torch.mean(Y)

        # define the reversible function without custom backprop and define graph for model 2
        # (hand-written reference math for each coupling)
        XX = X.clone().detach().requires_grad_()
        x1, x2 = torch.chunk(XX, 2, dim=1)
        if coupling == 'additive':
            y1 = x1 + c1.forward(x2)
            y2 = x2 + c2.forward(y1)
        elif coupling == 'affine':
            fmr2 = c1.forward(x2)
            fmr1 = torch.exp(fmr2)
            y1 = (x1 * fmr1) + fmr2
            gmr2 = c2.forward(y1)
            gmr1 = torch.exp(gmr2)
            y2 = (x2 * gmr1) + gmr2
        else:
            raise NotImplementedError()
        YY = torch.cat([y1, y2], dim=1)

        loss = torch.mean(YY)

        # compute gradients manually
        grads = torch.autograd.grad(loss, (XX, c1.weight, c2.weight, c1.bias, c2.bias), None, retain_graph=True)

        # compute gradients and perform optimization model 2
        loss.backward()
        optim1.step()

        # gradients computed manually match those of the .backward() pass
        assert torch.equal(c1.weight.grad, grads[1])
        assert torch.equal(c2.weight.grad, grads[2])
        assert torch.equal(c1.bias.grad, grads[3])
        assert torch.equal(c2.bias.grad, grads[4])

        # weights differ after training a single model?
        assert not torch.equal(c1.weight, c1_2.weight)
        assert not torch.equal(c2.weight, c2_2.weight)
        assert not torch.equal(c1.bias, c1_2.bias)
        assert not torch.equal(c2.bias, c2_2.bias)

        # compute gradients and perform optimization model 1
        loss2.backward()
        optim2.step()

        # input is contiguous tests
        assert Xin.is_contiguous()
        assert Y.is_contiguous()

        # weights are approximately the same after training both models?
        assert torch.allclose(c1.weight.detach(), c1_2.weight.detach())
        assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())
        assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())
        assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())

        # gradients are approximately the same after training both models?
        assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())
        assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())
        assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())
        assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())


================================================
FILE: memcnn/models/tests/test_split_dim.py
================================================
import pytest
import torch

from memcnn import AdditiveCoupling, AffineAdapterNaive, AffineCoupling


class Check(torch.nn.Module):
    # Asserts that its input has the expected size along `dim`, then
    # passes the input through unchanged.
    def __init__(self, dim, target_size):
        super(Check, self).__init__()
        self.dim = dim
        self.target_size = target_size

    def forward(self, fn_input):
        assert fn_input.size(self.dim) == self.target_size
        return fn_input


@pytest.mark.parametrize('dimension', [None, 0, 1, 2])
@pytest.mark.parametrize('coupling', [AdditiveCoupling, AffineCoupling])
@pytest.mark.parametrize('input_size', [(2, 2, 2), (2, 4, 8, 12)])
def test_split_dim(dimension, coupling, input_size):
    # split_dim defaults to 1 when not specified.
    dim = 1 if dimension is None else dimension
    module = Check(dim, input_size[dim] // 2)
    coupling_args = dict(adapter=AffineAdapterNaive) if coupling.__name__ == 'AffineCoupling' else dict()
    if dimension is not None:
        coupling_args["split_dim"] = dimension
    model
= coupling(module, **coupling_args)
    inp = torch.randn(input_size, requires_grad=False)
    output = model(inp)
    # couplings must be shape-preserving regardless of the chosen split dimension
    assert inp.shape == output.shape


================================================
FILE: memcnn/train.py
================================================
import argparse
import os
import logging
import torch
from memcnn.config import Config
from memcnn.experiment.manager import ExperimentManager
from memcnn.experiment.factory import load_experiment_config, experiment_config_parser
import memcnn.utils.log

logger = logging.getLogger('train')


def run_experiment(experiment_tags, data_dir, results_dir, start_fresh=False, use_cuda=False, workers=None,
                   experiments_file=None, *args, **kwargs):
    """Build and run a single experiment assembled from the given experiment tags.

    Raises RuntimeError when data_dir or results_dir does not exist.
    Extra **kwargs are merged into the trainer parameters from the config.
    """
    if not os.path.exists(data_dir):
        raise RuntimeError('Cannot find data_dir directory: {}'.format(data_dir))
    if not os.path.exists(results_dir):
        raise RuntimeError('Cannot find results_dir directory: {}'.format(results_dir))
    cfg = load_experiment_config(experiments_file, experiment_tags)
    logger.info(cfg)
    model, optimizer, trainer, trainer_params = experiment_config_parser(cfg, workers=workers, data_dir=data_dir)
    # experiment output dir is derived from the joined tags, e.g. "cifar10_revnet38"
    experiment_dir = os.path.join(results_dir, '_'.join(experiment_tags))
    manager = ExperimentManager(experiment_dir, model, optimizer)
    if start_fresh:
        logger.info('Starting fresh option enabled. Clearing all previous results...')
        manager.delete_dirs()
    manager.make_dirs()
    if use_cuda:
        manager.model = manager.model.cuda()
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
    # resume from the latest stored checkpoint, if any
    last_iter = manager.get_last_model_iteration()
    if last_iter > 0:
        logger.info('Continue experiment from iteration: {}'.format(last_iter))
        manager.load_train_state(last_iter)
    trainer_params.update(kwargs)
    trainer(manager, start_iter=last_iter, use_cuda=use_cuda, *args, **trainer_params)


def main(data_dir, results_dir):
    """Command-line entry point: parse arguments and dispatch to run_experiment."""
    # setup logging
    memcnn.utils.log.setup(True)
    # specify defaults for arguments
    use_cuda = torch.cuda.is_available()
    workers = 16
    experiments_file = os.path.join(os.path.dirname(__file__), 'config', 'experiments.json')
    start_fresh = False
    # parse arguments
    parser = argparse.ArgumentParser(description='Run memcnn experiments.')
    parser.add_argument('experiment_tags', type=str, nargs='+',
                        help='Experiment tags to run and combine from the experiment config file')
    parser.add_argument('--workers', dest='workers', type=int, default=workers,
                        help='Number of workers for data loading (Default: {})'.format(workers))
    parser.add_argument('--results-dir', dest='results_dir', type=str, default=results_dir,
                        help='Directory for storing results (Default: {})'.format(results_dir))
    parser.add_argument('--data-dir', dest='data_dir', type=str, default=data_dir,
                        help='Directory for input data (Default: {})'.format(data_dir))
    parser.add_argument('--experiments-file', dest='experiments_file', type=str, default=experiments_file,
                        help='Experiments file (Default: {})'.format(experiments_file))
    parser.add_argument('--fresh', dest='start_fresh', action='store_true', default=start_fresh,
                        help='Start with fresh experiment, clears all previous results (Default: {})'
                        .format(start_fresh))
    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false', default=use_cuda,
                        help='Always disables GPU use (Default: use when available)')
    args = parser.parse_args()
    # first warning covers hardware availability, second covers the explicit --no-cuda flag
    if not use_cuda:
        logger.warning('CUDA is not available in the current configuration!!!')
    if not args.use_cuda:
        logger.warning('CUDA is disabled!!!')
    # run experiment given arguments
    run_experiment(args.experiment_tags, args.data_dir, args.results_dir, start_fresh=args.start_fresh,
                   experiments_file=args.experiments_file, use_cuda=args.use_cuda, workers=args.workers)


if __name__ == '__main__':  # pragma: no cover
    config_fname = Config.get_filename()
    # NOTE(review): prefer the idiomatic "'data_dir' not in Config()" over "not 'data_dir' in Config()"
    if not os.path.exists(config_fname) or not 'data_dir' in Config() or not 'results_dir' in Config():
        print('The configuration file was not set correctly.\n')
        print('Please create a configuration file (json) at:\n {}\n'.format(config_fname))
        print('The configuration file should be formatted as follows:\n\n'
              '{\n'
              ' "data_dir": "/home/user/data",\n'
              ' "results_dir": "/home/user/experiments"\n'
              '}\n')
        print('data_dir : location for storing the input training datasets')
        print('results_dir : location for storing the experiment files during training')
    else:
        main(data_dir=Config()['data_dir'], results_dir=Config()['results_dir'])


================================================
FILE: memcnn/trainers/__init__.py
================================================


================================================
FILE: memcnn/trainers/classification.py
================================================
import time
import logging
import torch
import numpy as np
from memcnn.utils.stats import AverageMeter, accuracy
from memcnn.utils.log import SummaryWriter

logger = logging.getLogger('trainer')


def validate(model, ceriterion, val_loader, device):
    """validation sub-loop: returns (top1 accuracy average, loss average) over val_loader"""
    model.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    with torch.no_grad():
        for x, label in val_loader:
            x, label = x.to(device), label.to(device)
            vx, vl = x, label
            score = model(vx)
            loss = ceriterion(score, vl)
            prec1 = accuracy(score.data, label)
            losses.update(loss.item(), x.size(0))
            top1.update(prec1[0][0], x.size(0))
            batch_time.update(time.time() - end)
            end =
time.time()
    logger.info('Test: [{0}/{0}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(len(val_loader), batch_time=batch_time,
                                                                  loss=losses, top1=top1))
    return top1.avg, losses.avg


def get_model_parameters_count(model):
    """Total number of scalar parameters in the model."""
    return np.sum([np.prod([int(e) for e in p.shape]) for p in model.parameters()])


def train(manager, train_loader, test_loader, start_iter, disp_iter=100, save_iter=10000, valid_iter=1000,
          use_cuda=False, loss=None):
    """train loop

    Runs iteration-based training with periodic logging (disp_iter), validation (valid_iter)
    and checkpointing (save_iter), resuming at start_iter.
    NOTE(review): the `loss` parameter (the criterion module) is shadowed by the per-batch
    `loss` tensor inside the loop below; it is aliased to `ceriterion` first, so behavior is
    correct, but renaming the local would be clearer.
    """
    device = torch.device('cpu' if not use_cuda else 'cuda')
    model, optimizer = manager.model, manager.optimizer
    logger.info('Model parameters: {}'.format(get_model_parameters_count(model)))
    if use_cuda:
        # snapshot of memory used by the model itself, to separate it from activation memory
        model_mem_allocation = torch.cuda.memory_allocated(device)
        logger.info('Model memory allocation: {}'.format(model_mem_allocation))
    else:
        model_mem_allocation = None
    writer = SummaryWriter(manager.log_dir)
    data_time = AverageMeter()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    act_mem_activations = AverageMeter()
    ceriterion = loss
    # ensure train_loader enumerates to max_epoch
    # NOTE(review): this mutates the shared sampler in-place (nsamples -= start_iter);
    # re-using the same loader for a second train() call would shorten it further — confirm intended
    max_iterations = train_loader.sampler.nsamples // train_loader.batch_size
    train_loader.sampler.nsamples = train_loader.sampler.nsamples - start_iter
    end = time.time()
    for ind, (x, label) in enumerate(train_loader):
        iteration = ind + 1 + start_iter
        if iteration > max_iterations:
            logger.info('maximum number of iterations reached: {}/{}'.format(iteration, max_iterations))
            break
        # hard-coded learning-rate decay schedule (x0.1 at iterations 40000 and 60000)
        if iteration == 40000 or iteration == 60000:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        model.train()
        data_time.update(time.time() - end)
        end = time.time()
        x, label = x.to(device), label.to(device)
        vx, vl = x, label
        score = model(vx)
        loss = ceriterion(score, vl)
        if use_cuda:
            # memory used beyond the model weights is attributed to activations
            activation_mem_allocation = torch.cuda.memory_allocated(device) - model_mem_allocation
            act_mem_activations.update(activation_mem_allocation, iteration)
        # fail fast on divergence instead of training on silently
        if torch.isnan(loss):
            raise ValueError("Loss became NaN during iteration {}".format(iteration))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time()-end)
        prec1 = accuracy(score.data, label)
        losses.update(loss.item(), x.size(0))
        top1.update(prec1[0][0], x.size(0))
        if iteration % disp_iter == 0:
            act = ''
            if model_mem_allocation is not None:
                act = 'ActMem {act.val:.3f} ({act.avg:.3f})'.format(act=act_mem_activations)
            logger.info('iteration: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        '{act}'
                        .format(iteration, max_iterations, batch_time=batch_time, data_time=data_time,
                                loss=losses, top1=top1, act=act))
        if iteration % disp_iter == 0:
            writer.add_scalar('train_loss', loss.item(), iteration)
            writer.add_scalar('train_acc', prec1[0][0], iteration)
            losses.reset()
            top1.reset()
            data_time.reset()
            batch_time.reset()
            if use_cuda:
                writer.add_scalar('act_mem_allocation', act_mem_activations.avg, iteration)
                act_mem_activations.reset()
        if iteration % valid_iter == 0:
            test_top1, test_loss = validate(model, ceriterion, test_loader, device=device)
            writer.add_scalar('test_loss', test_loss, iteration)
            writer.add_scalar('test_acc', test_top1, iteration)
        if iteration % save_iter == 0:
            manager.save_train_state(iteration)
        # NOTE(review): flushing the scalar log every iteration rewrites the whole JSON file
        # each step; flushing only on disp_iter/save_iter boundaries would be cheaper
        writer.flush()
        end = time.time()
    writer.close()


================================================
FILE: memcnn/trainers/tests/__init__.py
================================================


================================================
FILE: memcnn/trainers/tests/resources/experiments.json
================================================
{
    "testsetup": {
        "model": "memcnn.trainers.tests.test_train.DummyModel",
        "model_params": {
            "block":"memcnn.trainers.tests.test_train.DummyDataset"
        },
        "optimizer": "torch.optim.SGD",
        "optimizer_params": {
            "lr":0.1
        },
        "trainer":
"memcnn.trainers.tests.test_train.dummy_trainer", "trainer_params": { "loss":"memcnn.trainers.tests.test_train.DummyDataset" }, "data_loader": "memcnn.trainers.tests.test_train.dummy_dataloaders", "data_loader_params": { "dataset": "memcnn.trainers.tests.test_train.DummyDataset", "workers": 0 } }, "resnet32": { "data_loader_params": { "batch_size": 100, "max_epoch": 80000 }, "model": "memcnn.models.resnet.ResNet", "model_params": { "block":"memcnn.models.resnet.BasicBlock", "layers":[5, 5, 5], "channels_per_layer":[16,16,32,64], "strides":[1, 1, 2, 2], "init_max_pool":false, "init_kernel_size":3, "batch_norm_fix":false }, "optimizer": "torch.optim.SGD", "optimizer_params": { "lr":0.1, "momentum":0.9, "weight_decay":2e-4 }, "trainer":"memcnn.trainers.classification.train", "trainer_params":{ "loss":"memcnn.utils.loss.CrossEntropyLossTF" } }, "resnet110": { "base": "resnet32", "model_params": { "layers":[18, 18, 18] } }, "resnet164": { "base": "resnet110", "model_params": { "block":"memcnn.models.resnet.Bottleneck" } }, "revnet38": { "base": "resnet32", "model_params": { "layers":[3, 3, 3], "channels_per_layer":[32,32,64,112], "block":"memcnn.models.resnet.RevBasicBlock" } }, "revnet110": { "base": "revnet38", "model_params": { "layers":[9, 9, 9], "channels_per_layer":[32,32,64,128] } }, "revnet164": { "base": "revnet110", "model_params": { "block":"memcnn.models.resnet.RevBottleneck" } }, "cifar10": { "data_loader": "memcnn.data.cifar.get_cifar_data_loaders", "data_loader_params": { "dataset": "torchvision.datasets.CIFAR10", "workers": 16 }, "model_params": { "num_classes":10 } }, "cifar100": { "data_loader": "memcnn.data.cifar.get_cifar_data_loaders", "data_loader_params": { "dataset": "torchvision.datasets.CIFAR100", "workers": 16 }, "model_params": { "num_classes":100 } }, "epoch5": { "data_loader_params": { "max_epoch": 5 } } } ================================================ FILE: memcnn/trainers/tests/test_classification.py 
================================================
import pytest
from memcnn.trainers.classification import train
from memcnn.experiment.manager import ExperimentManager
from memcnn.data.cifar import get_cifar_data_loaders
from memcnn.utils.loss import CrossEntropyLossTF
import torch
from torchvision.datasets.cifar import CIFAR10


class SimpleTestingModel(torch.nn.Module):
    """Minimal 1-conv classifier for 32x32 RGB inputs, used to exercise the train loop quickly."""
    def __init__(self, klasses):
        super(SimpleTestingModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, klasses, 1)
        self.avgpool = torch.nn.AvgPool2d(32)
        self.klasses = klasses

    def forward(self, x):
        # (N, 3, 32, 32) -> (N, klasses) class scores
        return self.avgpool(self.conv(x)).reshape(x.shape[0], self.klasses)


def test_train(tmp_path):
    expdir = str(tmp_path / "testexp")
    tmp_data_dir = str(tmp_path / "tmpdata")
    num_klasses = 10
    model = SimpleTestingModel(num_klasses)
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
    manager = ExperimentManager(expdir, model, optimizer)
    manager.make_dirs()
    train_loader, test_loader = get_cifar_data_loaders(CIFAR10, tmp_data_dir, 40000, 2, 0)
    loss = CrossEntropyLossTF()
    # start_iter=39999 so only a single iteration runs; every periodic action fires (all *_iter=1)
    train(manager, train_loader, test_loader, start_iter=39999, disp_iter=1, save_iter=1, valid_iter=1,
          use_cuda=False, loss=loss)


def test_train_with_nan_loss(tmp_path):
    class NanLoss(torch.nn.Module):
        """Criterion that always yields NaN, to trigger the train loop's NaN guard."""
        def __init__(self):
            super(NanLoss, self).__init__()

        def forward(self, Ypred, Y, W=None):
            return Ypred.mean() * float('nan')

    expdir = str(tmp_path / "testexp")
    tmp_data_dir = str(tmp_path / "tmpdata")
    num_klasses = 10
    model = SimpleTestingModel(num_klasses)
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
    manager = ExperimentManager(expdir, model, optimizer)
    manager.make_dirs()
    train_loader, test_loader = get_cifar_data_loaders(CIFAR10, tmp_data_dir, 40000, 2, 0)
    loss = NanLoss()
    with pytest.raises(ValueError) as e:
        train(manager, train_loader, test_loader, start_iter=1, disp_iter=1, save_iter=1, valid_iter=1,
              use_cuda=False, loss=loss)
    assert "Loss became NaN during iteration" in str(e.value)


================================================
FILE: memcnn/trainers/tests/test_train.py
================================================
import json
import pytest
import os
import sys
import torch
from memcnn.experiment.manager import ExperimentManager
from memcnn.train import run_experiment, main

# pathlib2 is the Python 2 backport; fall back to stdlib pathlib on Python 3
try:
    from pathlib2 import Path
except ImportError:
    from pathlib import Path


def test_main(tmp_path):
    # 'resnet34' is not a tag in the experiments config, so config lookup must raise KeyError
    sys.argv = ['train.py', 'cifar10', 'resnet34', '--fresh', '--no-cuda', '--workers=0']
    data_dir = str(tmp_path / "tmpdata")
    results_dir = str(tmp_path / "resdir")
    os.makedirs(data_dir)
    os.makedirs(results_dir)
    with pytest.raises(KeyError):
        main(data_dir=data_dir, results_dir=results_dir)


def dummy_dataloaders(*args, **kwargs):
    # stand-in for a (train_loader, test_loader) factory
    return None, None


def dummy_trainer(manager, *args, **kwargs):
    # minimal trainer: just persist a checkpoint at iteration 2
    manager.save_train_state(2)


class DummyDataset(object):
    def __init__(self, *args, **kwargs):
        pass


class DummyModel(torch.nn.Module):
    def __init__(self, block):
        super(DummyModel, self).__init__()
        self.block = block
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)


def test_run_experiment(tmp_path):
    exptags = ['testsetup']
    exp_file = str(Path(__file__).parent / "resources" / "experiments.json")
    data_dir = str(tmp_path / "tmpdata")
    results_dir = str(tmp_path / "resdir")
    run_params = dict(
        experiment_tags=exptags, data_dir=data_dir, results_dir=results_dir,
        start_fresh=True, use_cuda=False, workers=None, experiments_file=exp_file
    )
    # missing data_dir -> RuntimeError
    with pytest.raises(RuntimeError):
        run_experiment(**run_params)
    os.makedirs(data_dir)
    # missing results_dir -> RuntimeError
    with pytest.raises(RuntimeError):
        run_experiment(**run_params)
    os.makedirs(results_dir)
    run_experiment(**run_params)
    # second run without start_fresh resumes from the saved state
    run_params["start_fresh"] = False
    run_experiment(**run_params)


@pytest.mark.parametrize("network", [
    pytest.param(network, marks=pytest.mark.skipif(
        condition=("FULL_NETWORK_TESTS" not in os.environ) and ("revnet38" != network),
        reason="Too memory intensive for CI so these tests are disabled by default. "
               "Set FULL_NETWORK_TESTS environment variable to enable the tests.")
    ) for network in ["resnet32", "resnet110", "resnet164", "revnet38", "revnet110", "revnet164"]
])
@pytest.mark.parametrize("use_cuda", [
    False,
    pytest.param(True, marks=pytest.mark.skipif(condition=not torch.cuda.is_available(),
                                                reason="No GPU available"))
])
def test_train_networks(tmp_path, network, use_cuda):
    exptags = ["cifar10", network, "epoch5"]
    exp_file = str(Path(__file__).parent / "resources" / "experiments.json")
    data_dir = str(tmp_path / "tmpdata")
    results_dir = str(tmp_path / "resdir")
    os.makedirs(data_dir)
    os.makedirs(results_dir)
    run_experiment(experiment_tags=exptags, data_dir=data_dir, results_dir=results_dir,
                   start_fresh=True, use_cuda=use_cuda, workers=None, experiments_file=exp_file,
                   disp_iter=1, save_iter=5, valid_iter=5,)
    experiment_dir = os.path.join(results_dir, '_'.join(exptags))
    assert os.path.exists(experiment_dir)
    manager = ExperimentManager(experiment_dir)
    scalars_file = os.path.join(manager.log_dir, "scalars.json")
    assert os.path.exists(scalars_file)
    with open(scalars_file, "r") as f:
        results = json.load(f)
    # no results should hold any NaN values (NaN != NaN)
    assert not any([val != val for t, i, val in results["train_loss"]])


================================================
FILE: memcnn/utils/__init__.py
================================================


================================================
FILE: memcnn/utils/log.py
================================================
import os
import json
import logging
import sys
import time


def setup(use_stdout=True, filename=None, log_level=logging.DEBUG):
    """setup some basic logging on the root logger (stdout and/or file handlers)"""
    log = logging.getLogger('')
    log.setLevel(log_level)
    fmt = logging.Formatter("%(asctime)s [%(name)-15s] %(message)s", datefmt="%y-%m-%d %H:%M:%S")
    if use_stdout:
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(log_level)
        ch.setFormatter(fmt)
        log.addHandler(ch)
    if filename is not None:
        fh = logging.FileHandler(filename)
        fh.setLevel(log_level)
        fh.setFormatter(fmt)
        log.addHandler(fh)


class SummaryWriter(object):
    """Tiny tensorboard-like scalar logger that persists to a single JSON file.

    Scalars are stored as name -> list of [wall_time, iteration, value] and appended
    to any existing scalars.json found in log_dir.
    """
    def __init__(self, log_dir):
        self._log_dir = log_dir
        self._log_file = os.path.join(log_dir, "scalars.json")
        self._summary = {}
        self._load_if_exists()

    def _load_if_exists(self):
        # resume an existing scalar log so repeated runs append instead of overwrite
        if os.path.exists(self._log_file):
            with open(self._log_file, "r") as f:
                self._summary = json.load(f)

    def add_scalar(self, name, value, iteration):
        if name not in self._summary:
            self._summary[name] = []
        self._summary[name].append([time.time(), int(iteration), float(value)])

    def flush(self):
        # rewrites the whole file on every call
        with open(self._log_file, "w") as f:
            json.dump(self._summary, f)

    def close(self):
        self.flush()


================================================
FILE: memcnn/utils/loss.py
================================================
import torch
import torch.nn as nn
from torch.nn.modules.module import Module


def _assert_no_grad(variable):
    msg = "nn criterions don't compute the gradient w.r.t. targets - please " \
          "mark these variables as not requiring gradients"
    assert not variable.requires_grad, msg  # nosec


class CrossEntropyLossTF(Module):
    """TensorFlow-style cross-entropy: mean over the one-hot weighted log-softmax,
    scaled by the number of classes. Optional per-element weights W."""
    def __init__(self):
        super(CrossEntropyLossTF, self).__init__()

    def forward(self, Ypred, Y, W=None):
        _assert_no_grad(Y)
        # NOTE(review): torch.log(softmax(x)) is numerically less stable than
        # log_softmax(x); consider nn.LogSoftmax / F.log_softmax here
        lsm = nn.Softmax(dim=1)
        y_onehot = torch.zeros(Ypred.shape[0], Ypred.shape[1], dtype=torch.float32, device=Ypred.device)
        y_onehot.scatter_(1, Y.data.view(-1, 1), 1)
        if W is not None:
            y_onehot = y_onehot * W
        return torch.mean(-y_onehot * torch.log(lsm(Ypred))) * Ypred.shape[1]


================================================
FILE: memcnn/utils/stats.py
================================================
"""
Module containing utilities to compute statistics
Some bits from: https://gist.github.com/xmfbit/67c407e34cbaf56e7820f09e774e56d8
"""


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0    # most recent value
        self.avg = 0    # running average (sum / count)
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight (number of samples)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# top-k accuracy
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk highest scores per sample, transposed to (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # a sample counts as correct@k if the target is among its top-k predictions
        correct_k = correct[:k].contiguous().view(-1).float().sum(dim=0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


================================================
FILE: memcnn/utils/tests/__init__.py
================================================


================================================
FILE: memcnn/utils/tests/test_log.py
================================================
import logging
from memcnn.utils.log import setup, SummaryWriter


def test_setup(tmp_path):
    logfile = str(tmp_path / 'testlog.log')
    setup(use_stdout=True, filename=logfile, log_level=logging.DEBUG)


def test_summary_writer(tmp_path):
    logfile = tmp_path / 'scalars.json'
    assert not logfile.exists()
    writer = SummaryWriter(log_dir=str(tmp_path))
    writer.add_scalar("test_value", 0.5, 1)
    writer.add_scalar("test_value", 2.5, 2)
    writer.add_scalar("test_value2", 123, 1)
    writer.flush()
    assert logfile.exists()
    # a fresh writer over the same dir must pick up the persisted scalars
    writer = SummaryWriter(log_dir=str(tmp_path))
    assert "test_value" in writer._summary
    assert "test_value2" in writer._summary
    assert len(writer._summary["test_value"]) == 2
    writer.add_scalar("test_value", 123.4, 3)
    writer.close()
    # close() flushes, so the appended scalar must survive a reload as well
    writer = SummaryWriter(log_dir=str(tmp_path))
    assert "test_value" in writer._summary
    assert "test_value2" in writer._summary
    assert len(writer._summary["test_value"]) == 3


================================================
FILE: memcnn/utils/tests/test_loss.py
================================================
import torch
from memcnn.utils.loss import _assert_no_grad, CrossEntropyLossTF


def test_assert_no_grad():
    data = torch.ones(3, 3, 3)
    data.requires_grad = False
    _assert_no_grad(data)


def test_crossentropy_tf():
    batch_size = 5
    shape = (batch_size, 2)
    loss = CrossEntropyLossTF()
    ypred = torch.ones(*shape)
    ypred.requires_grad = True
    y = torch.ones(batch_size, dtype=torch.int64)
    y.requires_grad = False
    w = torch.ones(*shape)
    w.requires_grad = False
    w2 = torch.zeros(*shape)
    w2.requires_grad = False
    # unweighted and weighted losses must be scalar tensors
    out1 = loss(ypred, y)
    assert len(out1.shape) == 0
    out2 = loss(ypred, y, w)
    assert len(out2.shape) == 0
    # an all-zero weight mask must zero out the loss entirely
    out3 = loss(ypred, y, w2)
    assert out3 == 0
    assert len(out3.shape) == 0


================================================
FILE: memcnn/utils/tests/test_stats.py
================================================
import pytest
import torch
from memcnn.utils.stats import AverageMeter, accuracy


@pytest.mark.parametrize('val,n', [(1, 1), (14, 10), (10, 14), (5, 1), (1, 5), (0, 10)])
def test_average_meter(val, n):
    meter = AverageMeter()
    assert meter.val == 0
    assert meter.avg == 0
    assert meter.sum == 0
    assert meter.count == 0
    meter.update(val, n=n)
    assert meter.val == val
    assert meter.avg == val
    assert meter.sum == val * n
    assert meter.count == n


@pytest.mark.parametrize('topk,klass', [((1,), 4), ((1, 3,), 2), ((5,), 1)])
def test_accuracy(topk, klass, num_klasses=5):
    # build output / target so that every sample's top prediction is `klass`
    batch_size = 5
    target = torch.ones(batch_size, dtype=torch.long) * klass
    output = torch.zeros(batch_size, num_klasses)
    output[:, klass] = 1
    res = accuracy(output, target, topk)
    assert len(res) == len(topk)
    assert all([e == 100.0 for e in res])


================================================
FILE: paper/README
================================================
The paper can be compiled locally using the following command:

pandoc paper.md --bibliography paper.bib -o paper_local.pdf


================================================
FILE: paper/paper.bib
================================================
@misc{Gomez17,
author = {A. N. Gomez and M. Ren and R. Urtasun and R. B.
Grosse}, title = {The Reversible Residual Network: Backpropagation Without Storing Activations}, howpublished = {{\tt arXiv:1707.04585 [cs.CV]}}, url = {https://arxiv.org/abs/1707.04585}, year = 2017 } @misc{Dinh14, author = {L. Dinh and D. Krueger and Y. Bengio}, title = {{NICE:} Non-linear Independent Components Estimation}, howpublished = {{\tt arXiv:1410.8516 [cs.LG]}}, url = {https://arxiv.org/abs/1410.8516}, year = 2014 } @article{He2016, title = {Identity Mappings in Deep Residual Networks}, note = {{\tt arXiv:1603.05027 [cs.CV]}}, year = 2016, doi = {10.1007/978-3-319-46493-0_38}, publisher = {Springer International Publishing}, pages = {630--645}, author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, booktitle = {Computer Vision {\textendash} {ECCV} 2016} } @article{He2015, note = {{\tt arXiv:1512.03385 [cs.CV]}}, doi = {10.1109/cvpr.2016.90}, year = {2016}, month = jun, publisher = {{IEEE}}, author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun}, title = {Deep Residual Learning for Image Recognition}, booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition ({CVPR})} } @misc{Chang17, author = {B. Chang and L. Meng and E. Haber and L. Ruthotto and D. Begert and E. 
Holtham}, title = {Reversible Architectures for Arbitrarily Deep Residual Neural Networks}, howpublished = {{\tt arXiv:1709.03698 [cs.CV]}}, url = {https://arxiv.org/abs/1709.03698}, year = 2017 } @mastersthesis{krizhevsky2009learning, author = {Krizhevsky, A.}, title = {Learning Multiple Layers of Features from Tiny Images}, school = {University of Toronto}, year = 2009, address = {Toronto, Ontario, Canada}, month = apr, } @inproceedings{imagenet_cvpr09, doi = {10.1109/cvprw.2009.5206848}, year = 2009, month = jun, publisher = {{IEEE}}, author = {Jia Deng and Wei Dong and Richard Socher and Li-Jia Li and Kai Li and Li Fei-Fei}, title = {{ImageNet}: A large-scale hierarchical image database}, booktitle = {2009 {IEEE} Conference on Computer Vision and Pattern Recognition} } @inproceedings{jaco18, author = {J.-H. Jacobsen and A.W.M. Smeulders and E. Oyallon}, title = {{i-RevNet}: Deep Invertible Networks}, booktitle = {ICLR}, year = {2018}, url = {https://arxiv.org/abs/1802.07088}, howpublished = {{\tt arXiv:1802.07088 [cs.LG]}} } @misc{TF2015, title={{TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems}, note={Software available from tensorflow.org}, author={ M.~Abadi and A.~Agarwal and P.~Barham and E.~Brevdo and Z.~Chen and C.~Citro and G.~S..~Corrado and A.~Davis and J.~Dean and M.~Devin and S.~Ghemawat and I.~Goodfellow and A.~Harp and G.~Irving and M.~Isard and Y.Jia and R.~Jozefowicz and L.~Kaiser and M.~Kudlur and J.~Levenberg and D.~Man\'{e} and R.~Monga and S.~Moore and D.~Murray and C.~Olah and M.~Schuster and J.~Shlens and B.~Steiner and I.~Sutskever and K.~Talwar and P.~Tucker and V.~Vanhoucke and V.~Vasudevan and F.~Vi\'{e}gas and O.~Vinyals and P.~Warden and M.~Wattenberg and M.~Wicke and Y.~Yu and X.~Zheng}, year={2015}, url={http://tensorflow.org/}, } @inproceedings{paszke2017automatic, title={Automatic differentiation in {PyTorch}}, author={Paszke, A. and Gross, S. and Chintala, S. and Chanan, G. and Yang, E. and DeVito, Z. 
and Lin, Z. and Desmaison, A. and Antiga, L. and Lerer, A.}, booktitle={NIPS-W}, year={2017}, howpublished = {\url{https://openreview.net/forum?id=BJJsrmfCZ}}, } @inproceedings{kingma2018glow, title = {Glow: Generative Flow with Invertible 1x1 Convolutions}, author = {Kingma, Durk P and Dhariwal, Prafulla}, booktitle = {Advances in Neural Information Processing Systems 31}, editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett}, pages = {10215--10224}, year = {2018}, publisher = {Curran Associates, Inc.}, url = {http://papers.nips.cc/paper/8224-glow-generative-flow-with-invertible-1x1-convolutions.pdf} } @misc{dinh2016density, title={Density estimation using {Real NVP}}, author={Dinh, Laurent and Sohl-Dickstein, Jascha and Bengio, Samy}, year={2016}, howpublished = {{\tt arXiv:1605.08803 [cs.LG]}}, url = {https://arxiv.org/abs/1605.08803} } @incollection{martens2012training, doi = {10.1007/978-3-642-35289-8_27}, year = 2012, publisher = {Springer Berlin Heidelberg}, pages = {479--535}, author = {James Martens and Ilya Sutskever}, title = {Training Deep and Recurrent Networks with Hessian-Free Optimization}, booktitle = {Neural Networks: Tricks of the Trade: Second Edition} } @misc{chen2016training, title={Training deep nets with sublinear memory cost}, author={Chen, Tianqi and Xu, Bing and Zhang, Chiyuan and Guestrin, Carlos}, howpublished = {{\tt arXiv:1604.06174 [cs.LG]}}, url = {https://arxiv.org/abs/1604.06174}, year=2016 } @InProceedings{Ouderaa_2019_CVPR, author = {Ouderaa, Tycho F.A. 
van der and Worrall, Daniel E.}, title = {Reversible {GANs} for Memory-Efficient Image-To-Image Translation}, booktitle = {{The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}}, month = jun, year = 2019, howpublished = {{\tt arXiv:1902.02729 [cs.CV]}}, url = {https://arxiv.org/abs/1902.02729}, } @inproceedings{zhu2017unpaired, title={Unpaired image-to-image translation using cycle-consistent adversarial networks}, author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, booktitle={{Proceedings of the IEEE International Conference on Computer Vision}}, pages={2223--2232}, year={2017}, doi={10.1109/iccv.2017.244}, } @inproceedings{ouderaa:MIDLAbstract2019a, title={Chest {CT} Super-resolution and Domain-adaptation using Memory-efficient 3D Reversible {GAN}s}, author={Tycho F.A. van der Ouderaa and Daniel E. Worrall and Bram van Ginneken}, booktitle={International Conference on Medical Imaging with Deep Learning}, address={London, United Kingdom}, year=2019, month=jul, url={https://openreview.net/forum?id=SkxueFsiFV} } ================================================ FILE: paper/paper.md ================================================ --- title: 'MemCNN: A Python/PyTorch package for creating memory-efficient invertible neural networks' tags: - MemCNN - Python - PyTorch - machine learning - invertible networks - deep learning authors: - name: Sil C. van de Leemput orcid: 0000-0001-6047-3051 affiliation: 1 - name: Jonas Teuwen affiliation: 1 - name: Bram van Ginneken affiliation: 1 - name: Rashindra Manniesing affiliation: 1 affiliations: - name: Radboud University Medical Center, Department of Radiology and Nuclear Medicine, Nijmegen, The Netherlands index: 1 date: 28 June 2019 bibliography: paper.bib --- # Summary Neural networks are computational models that were originally inspired by biological neural networks like animal brains. 
These networks are composed of many small computational units called neurons that perform elementary calculations. Instead of explicitly programming the behavior of neural networks, these models can be trained to perform tasks, like classifying images, by presenting them examples. Sufficiently complex neural networks can automatically extract task-relevant characteristics from the presented examples without having prior knowledge about the task domain, which makes them attractive for many complicated real-world applications. Reversible operations have recently been successfully applied to classification problems to reduce memory requirements during neural network training. This feature is accomplished by removing the need to store the input activation for computing the gradients at the backward pass and instead reconstruct them on demand. However, current approaches rely on custom implementations of backpropagation, which limits applicability and extendibility. We present MemCNN, a novel PyTorch framework that simplifies the application of reversible functions by removing the need for a customized backpropagation. The framework contains a set of practical generalized tools, which can wrap common operations like convolutions and batch normalization and which take care of memory management. We validate the presented framework by reproducing state-of-the-art experiments using MemCNN and by comparing classification accuracy and training time on Cifar-10 and Cifar-100. Our MemCNN implementations achieved similar classification accuracy and faster training times while retaining compatibility with the default backpropagation facilities of PyTorch. # Background Reversible functions, which allow exact retrieval of its input from its output, can reduce memory overhead when used within the context of training neural networks using backpropagation. 
This is because only the output needs to be stored: intermediate feature maps can be freed on the forward pass and recomputed from the output on the backward pass when required. Recently, reversible functions have been used with some success to extend the well-established residual network (ResNet) for image classification from @He2015 to more memory-efficient invertible convolutional neural networks [@Gomez17; @Chang17; @jaco18] showing competitive performance on datasets like Cifar-10, Cifar-100 [@krizhevsky2009learning] and ImageNet [@imagenet_cvpr09]. However, practical applicability and extendibility of reversible functions for the reduction of memory overhead have been limited, since current implementations require customized backpropagation, which does not work conveniently with modern deep learning frameworks and requires substantial manual design. The reversible residual network (RevNet) of @Gomez17 is a variant on ResNet, which hooks into its sequential structure of residual blocks and replaces them with reversible blocks, which create an explicit inverse for the residual blocks based on the equations from @Dinh14 on nonlinear independent components estimation. The reversible block takes arbitrary nonlinear functions $\mathcal{F}$ and $\mathcal{G}$ and renders them invertible. Their experiments show that RevNet scores similar classification performance on Cifar-10, Cifar-100, and ImageNet, with less memory overhead. Reversible architectures like RevNet have subsequently been studied in the framework of ordinary differential equations (ODE) [@Chang17]. Three reversible neural networks based on Hamiltonian systems are proposed, which are similar to the RevNet, but have a specific choice for the nonlinear functions $\mathcal{F}$ and $\mathcal{G}$ which are shown to be stable during training within the ODE framework on Cifar-10 and Cifar-100. 
The i-RevNet architecture extends the RevNet architecture by also making the downscale operations invertible [@jaco18], effectively creating a fully invertible architecture up until the last layer, while still showing good classification accuracy compared to ResNet on ImageNet. One particularly interesting finding shows that bottlenecks are not a necessary condition for training neural networks, which shows that the study of invertible networks can lead to a better understanding of neural network training in general. The different reversible architectures proposed in the literature [@Gomez17; @Chang17; @jaco18] have all been modifications of the ResNet architecture and all have been implemented in TensorFlow [@TF2015]. However, these implementations rely on custom backpropagation, which limits the creation of novel invertible networks and application of the concepts beyond the application architecture. Our proposed framework MemCNN overcomes this issue by being compatible with the default backpropagation facilities of PyTorch. Furthermore, PyTorch offers convenient features over other deep learning frameworks like a dynamic computation graph and simple inspection of gradients during backpropagation, which facilitates inspection of invertible operations in neural networks. # Methods ## The reversible block The core operator of MemCNN is the reversible block, an operator which takes a function $f$ and outputs a function $R : X \to Y$, and an inverse function $R^{-1} : Y \rightarrow X$ which resembles an invertible version of $f$. Here, $x\in X$ and $y\in Y$ can be arbitrary tensors with the same size and number of dimensions, i.e.: $\operatorname{shape}(x)=\operatorname{shape}(y)$. Additionally, it must be possible to partition the input $x=(x_1, x_2)$ and output tensors $y=(y_1, y_2)$ in half, where each partition has the same shape, i.e.: $\operatorname{shape}(x_1) = \operatorname{shape}(x_2) = \operatorname{shape}(y_1) = \operatorname{shape}(y_2)$. 
Formally, the reversible block operation (1), its inverse (2), and its partition constraints (3) provide a sufficiently general framework for implementing reversible operations. For example, if one wants to create a reversible block performing a convolution followed by a ReLU $f$, the input $x \in X$ is partitioned into $(x_1, x_2)$ of equal sizes to which this convolution block $f$ is applied twice (say $\mathcal{F}$ and $\mathcal{G}$). The reversible block takes these two operators ($\mathcal{F}$ and $\mathcal{G}$) and outputs a "resblock"-like version $R$ of the operator and an explicit inverse $R^{-1}$. Effectively, the learnable function $f$ is replaced by a learnable approximation $R$ with an explicit inverse $R^{-1}$. \begin{equation} \quad R(x) = y \end{equation} \begin{equation} R^{-1}(y) = x \end{equation} with \begin{equation} \operatorname{shape}(x_1) = \operatorname{shape}(x_2) = \operatorname{shape}(y_1) = \operatorname{shape}(y_2) \end{equation} ## Couplings Using the above definitions we provide two different implementations for the reversible block in MemCNN, which we will call `couplings'. A coupling provides a reversible mapping from $(x_1, x_2)$ to $(y_1, y_2)$. MemCNN supports two couplings: the additive coupling and the affine coupling. ### Additive coupling Equation (4) represents the additive coupling, which follows the equations of @Dinh14 and @Gomez17. These support a reversible implementation through arbitrary (nonlinear) functions $\mathcal{F}$ and $\mathcal{G}$. These functions can be convolutions, ReLUs, etc., as long as they have matching input and output shapes. The additive coupling is obtained by first computing $y_1$ from input partitions $x_1, x_2$ and function $\mathcal{F}$ and subsequently $y_2$ is computed from partitions $y_1, x_2$ and function $\mathcal{G}$. Next, (4) can be rewritten to obtain an exact inverse function as shown in (5). Figure 1 shows a graphical representation of the additive coupling and its inverse. 
![Graphical representation of additive coupling. The left graph shows the forward computations and the right graph shows its inverse. First, input $x_1$ and $\mathcal{F}(x_2)$ are added to form $y_1$, next $x_2$ and $\mathcal{G}(y_1)$ are added to form $y_2$. Going backwards, first, $\mathcal{G}(y_1)$ is subtracted from $y_2$ to obtain $x_2$; subsequently, $\mathcal{F}(x_2)$ is subtracted from $y_1$ to obtain $x_1$. Here, $+$ and $-$ stand for element-wise summation and element-wise subtraction, respectively.](additive_005.pdf) \begin{equation}\label{eq:additiveforward} \begin{split} y_1 &= x_1 + \mathcal{F}(x_2), \\ y_2 &= x_2 + \mathcal{G}(y_1) \\ \end{split} \end{equation} \begin{equation}\label{eq:additivebackward} \begin{split} x_2 &= y_2 - \mathcal{G}(y_1), \\ x_1 &= y_1 - \mathcal{F}(x_2) \end{split} \end{equation} ### Affine coupling Equation (6) gives the affine coupling, introduced by @dinh2016density and later used by @kingma2018glow, which is more expressive than the additive coupling. The affine coupling, similar to the additive coupling, supports a reversible implementation through arbitrary (nonlinear) functions $\mathcal{F}$ and $\mathcal{G}$. It also first computes $y_1$ from input partitions $x_1, x_2$ and function $\mathcal{F}$ and subsequently it computes $y_2$ from partitions $y_1, x_2$ and function $\mathcal{G}$. The difference with the additive coupling is that now the functions $\mathcal{F}=(s,t)$ and $\mathcal{G}=(s',t')$ each produce two equally sized partitions for scaling and translation, so $\operatorname{shape}(x_1) = \operatorname{shape}(s) = \operatorname{shape}(t) = \operatorname{shape}(s') = \operatorname{shape}(t')$ holds. These components are then used to compute the output using element-wise product ($\odot$) and element-wise exponentiation with base $e$ and element-wise addition ($+$). 
Equation (6) can be rewritten to obtain an exact inverse function as shown in (7), which uses element-wise division ($/$) and element-wise subtraction ($-$). Figure 2 shows a graphical representation of the affine coupling and its inverse. ![Graphical representation of the affine coupling. The left graph shows the forward computations and the right graph shows its inverse. Here, $\odot, /, +, -,$ and $e$ stand for element-wise multiplication, element-wise division, element-wise addition, element-wise subtraction, and element-wise exponentiation with base $e$ respectively. First, $s, t$ are computed for $\mathcal{F}(x_2)$, next input $x_1$ is element-wise multiplied with $e^{s}$ and added to $t$ to form $y_1$, subsequently $s', t'$ are computed for $\mathcal{G}(y_1)$ and then $x_2$ is element-wise multiplied with $e^{s'}$ and added to $t'$ to form $y_2$.](affine_005.pdf) \begin{equation}\label{eq:affineforward} \begin{split} y_1 &= x_1 \odot e^{s} + t \;\;\;\;\, \text{with} \;\; \mathcal{F}(x_2) = (s, t) \\ y_2 &= x_2 \odot e^{s'} + t' \;\;\, \text{with} \;\;\; \mathcal{G}(y_1) = (s', t') \end{split} \end{equation} \begin{equation}\label{eq:affinebackward} \begin{split} x_2 &= (y_2 - t') / e^{s'} \;\;\, \text{with} \;\;\; \mathcal{G}(y_1) = (s', t') \\ x_1 &= (y_1 - t) / e^{s} \;\;\;\:\: \text{with} \;\; \mathcal{F}(x_2) = (s, t) \end{split} \end{equation} ## Implementation details The reversible block has been implemented as a \texttt{torch.nn.Module} which wraps other PyTorch modules of arbitrary complexity for coupling functions $\mathcal{F}$ and $\mathcal{G}$. Each memory saving coupling is implemented using at least one \texttt{torch.autograd.Function}, which provides a custom forward and backward pass that works with the automatic differentiation system of PyTorch. 
Memory savings are implemented at the level of the reversible block and are achieved by setting the size of the underlying tensor storage to zero for inputs on the forward pass and restoring the storage size to the original size on the backward pass once it is required for computing gradients. ## Building larger networks The reversible block $R$ can be chained by subsequent reversible blocks, e.g.: $R_3 \circ R_2 \circ R_1$ for reversible blocks $R_1, R_2, R_3$, which creates a fully reversible chain of operations (see Figure 3). Additionally, reversible blocks can be mixed with regular functions $f$, e.g. $f \circ R$ or $R \circ f$ for reversible block $R$ and regular function $f$. Note that mixing regular functions with reversible blocks often breaks the invertibility of reversible chains. ![Graphical representation of chaining multiple reversible block layers. ](coupling_001.pdf) ## Memory savings **Table 1:** Comparison of memory and computational complexity for training a residual network (ResNet) between various memory saving techniques (extended table from @Gomez17). $L$ depicts the number of residual layers in the ResNet. \begin{center} \vspace{0.3cm} \begin{tabular}{llp{2.5cm}p{2.5cm}} \hline \textbf{Technique} & \textbf{Authors} & \textbf{Memory Complexity} & \textbf{Computational Complexity} \\ \hline Naive & & $O(L)$ & $O(L)$ \\ Checkpointing & Martens et al. (2012) & $O(\sqrt{L})$ & $O(L)$ \\ Recursive & Chen et al. (2016) & $O(\log L)$ & $O(L \log L)$ \\ Additive coupling & Gomez et al. (2017) & $O(1)$ & $O(L)$ \\ Affine coupling & Dinh et al. (2016) & $O(1)$ & $O(L)$ \\ \hline \end{tabular} \vspace{0.6cm} \end{center} The reversible block model has an advantageous memory footprint when chained in a sequence when training neural networks. After computing each $R(x) = y$ by (1) on the forward pass, input $x$ can be freed from memory and be recomputed on the backward pass, using the inverse function $R^{-1}(y)=x$ from (2). 
Once the input is restored, the gradients for the weights and the inputs can be recomputed as normal using the PyTorch `autograd' solver. This effectively yields a memory complexity of $O(1)$ in the number of chained reversible blocks. Table 1 shows a comparison of memory versus computational complexity for different memory saving techniques. \break # Experiments and results **Table 2a:** Accuracy comparison of the PyTorch implementation (MemCNN) versus the Tensorflow implementation from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning]. Accuracies were approximately similar between implementations. \begin{center} \vspace{0.3cm} \begin{tabular}{lcccc} \hline & \multicolumn{2}{c}{\textbf{Cifar-10}} & \multicolumn{2}{c}{\textbf{Cifar-100}} \\ \textbf{Model} & \textbf{Tensorflow} & \textbf{PyTorch} & \textbf{Tensorflow} & \textbf{PyTorch} \\ \hline ResNet-32 & 92.74 & 92.86 & 69.10 & 69.81 \\ ResNet-110 & 93.99 & 93.55 & 73.30 & 72.40 \\ ResNet-164 & 94.57 & 94.80 & 76.79 & 76.47 \\ RevNet-38 & 93.14 & 92.80 & 71.17 & 69.90 \\ RevNet-110 & 94.02 & 94.10 & 74.00 & 73.30 \\ RevNet-164 & 94.56 & 94.90 & 76.39 & 76.90 \\ \hline \end{tabular} \vspace{0.3cm} \end{center} **Table 2b:** Training time (in hours:minutes) comparison of the PyTorch implementation (MemCNN) versus the Tensorflow implementation from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning]. Training times were significantly less for the PyTorch implementation than for the Tensorflow implementation. 
\begin{center} \vspace{0.3cm} \begin{tabular}{lcccc} \hline & \multicolumn{2}{c}{\textbf{Cifar-10}} & \multicolumn{2}{c}{\textbf{Cifar-100}} \\ \textbf{Model} & \textbf{Tensorflow} & \textbf{PyTorch} & \textbf{Tensorflow} & \textbf{PyTorch} \\ \hline ResNet-32 & \;\,\,2:04 & 1:51 & \;\,\,1:58 & 1:51 \\ ResNet-110 & \;\,\,4:11 & 2:51 & \;\,\,6:44 & 2:39 \\ ResNet-164 & 11:05 & 4:59 & 10:59 & 3:45 \\ RevNet-38 & \;\,\,2:17 & 2:09 & \;\,\,2:20 & 2:16 \\ RevNet-110 & \;\,\,6:59 & 3:42 & \;\,\,7:03 & 3:50 \\ RevNet-164 & 13:09 & 7:21 & 13:12 & 7:17 \\ \hline \end{tabular} \vspace{0.6cm} \end{center} To validate MemCNN, we reproduced the experiments from @Gomez17 on Cifar-10 and Cifar-100 [@krizhevsky2009learning] using their Tensorflow [@TF2015] implementation on GitHub\footnote{\url{https://github.com/renmengye/revnet-public}}, and made a direct comparison with our PyTorch implementation on accuracy and train time. We have tried to keep all the experimental settings, like data loading, loss function, train procedure, and training parameters, as similar as possible. All experiments were performed on a single NVIDIA GeForce GTX 1080 with 8GB of RAM. The accuracies and training time results are listed in respectively Table 2a and Table 2b. Model performance of our PyTorch implementation obtained similar accuracy to the TensorFlow implementation with less training time on Cifar-10 and Cifar-100. All models and experiments are included in MemCNN and can be rerun for reproducibility. Table 3 shows memory usage statistics (parameters and activations) during training for all PyTorch models. Here, the ResNet model uses a conventional implementation and the RevNet model uses the reversible blocks from MemCNN. The results show that significant activation memory reduction was obtained using the reversible block implementation (RevNet) when the number of layers of the models increased. 
\break **Table 3:** Model statistics for all PyTorch model implementations on memory usage (parameters and activations) in MB during training and the number of layers and parameters. The ResNet model was implemented using a conventional non-reversible implementation while the RevNet model uses MemCNN with memory saving reversible blocks. To facilitate comparison, each row lists the statistics of one ResNet and one RevNet model which have a comparable number of layers and number of parameters. Significant memory savings for the activations were observed when using reversible operations (RevNet) as the number of layers increased. Model parameter memory usage stayed roughly the same between implementations. \begin{center} \vspace{0.3cm} \begin{tabular}{rr rr rr rr} \hline \multicolumn{2}{c}{\textbf{Layers}} & \multicolumn{2}{c}{\textbf{Parameters}} & \multicolumn{2}{c}{\textbf{Parameters (MB)}} & \multicolumn{2}{c}{\textbf{Activations (MB)}} \\ \multicolumn{2}{c}{\textbf{ResNet RevNet}} & \multicolumn{2}{c}{\textbf{ResNet RevNet}} & \multicolumn{2}{c}{\textbf{ResNet RevNet}} & \multicolumn{2}{c}{\textbf{ResNet RevNet}} \\ \hline \quad 32 & 38 \quad \quad & 466906 & 573994 & \quad \enspace 1.9 & 2.3 \quad \enspace & \quad 238.6 & 85.6 \enspace \enspace \\ \quad 110 & 110 \quad \quad & 1730714 & 1854890 & \quad \enspace 6.8 & 7.3 \quad \enspace & \quad 810.7 & 85.7 \enspace \enspace \\ \quad 164 & 164 \quad \quad & 1704154 & 1983786 & \quad \enspace 6.8 & 7.9 \quad \enspace & \quad 2452.8 & 432.7 \enspace \enspace \\ \hline \end{tabular} \vspace{0.6cm} \end{center} # Works using MemCNN MemCNN has recently been used to create reversible GANs for memory-efficient image-to-image translation by @Ouderaa_2019_CVPR. Image-to-image translation considers the problem of mapping both $X \rightarrow Y$ and $Y \rightarrow X$ given two image domains $X$ and $Y$ using either paired or unpaired examples. 
In this work, the CycleGAN [@zhu2017unpaired] model has been enlarged and extended with an invertible core using the reversible block, which they call RevGAN. Since the invertible core is weight tied, training the model for the mapping $X \rightarrow Y$ automatically trains the model for mapping $Y \rightarrow X$. They show similar or increased performance of RevGAN with respect to similar non-invertible models like the CycleGAN with less memory overhead during training. The RevGAN model has also been applied to chest CT images [@ouderaa:MIDLAbstract2019a]. # Conclusion We have presented MemCNN, a novel PyTorch framework, for creating and applying reversible operations for neural networks. It shows similar accuracy on Cifar-10 and Cifar-100 datasets with the current state-of-the-art method for reversible operations in Tensorflow and provides overall faster training times. The main features of the framework are smooth integration of reversible functions with other non-reversible functions by removing the need for a custom backpropagation and simple wrapping of arbitrary complex non-invertible nonlinear functions. The presented framework is intended to facilitate the study and application of invertible functions in the context of neural networks. # Acknowledgements This work was supported by research grants from the Netherlands Organization for Scientific Research (NWO), the Netherlands and Canon Medical Systems Corporation, Japan. 
# References ================================================ FILE: requirements.txt ================================================ numpy SimpleITK torch>=1.0.0 torchvision tqdm pathlib2 ================================================ FILE: setup.cfg ================================================ [bumpversion] current_version = 1.5.2 commit = True tag = True tag_name = {new_version} [bumpversion:file:setup.py] search = VERSION = '{current_version}' replace = VERSION = '{new_version}' [bumpversion:file:memcnn/__init__.py] search = __version__ = '{current_version}' replace = __version__ = '{new_version}' [bdist_wheel] universal = 1 [flake8] exclude = docs [aliases] test = pytest [tool:pytest] collect_ignore = ['setup.py'] ================================================ FILE: setup.py ================================================ import os import sys from distutils.core import setup from setuptools.command.install import install from setuptools import find_packages # circleci.py version VERSION = '1.5.2' with open('README.rst', 'r') as fh: long_description = fh.read().split('Results\n-------')[0] with open('requirements.txt', 'r') as fh: requirements = [e.strip() for e in fh.readlines() if e.strip() != ''] class VerifyVersionCommand(install): """Custom command to verify that the git tag matches our version""" description = 'verify that the git tag matches our version' def run(self): tag = os.getenv('CIRCLE_TAG') if tag != VERSION: info = "Git tag: {0} does not match the version of this app: {1}".format( tag, VERSION ) sys.exit(info) setup( name='memcnn', version=VERSION, author='S.C. 
van de Leemput', author_email='silvandeleemput@gmail.com', packages=find_packages(), include_package_data=True, scripts=[], url='http://pypi.python.org/pypi/memcnn/', license='LICENSE.txt', description='A PyTorch framework for developing memory efficient deep invertible networks.', long_description=long_description, long_description_content_type='text/x-rst', install_requires=requirements, classifiers=[ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Medical Science Apps.", "Topic :: Scientific/Engineering :: Information Analysis", "Topic :: Software Development :: Libraries", "Operating System :: OS Independent" ], keywords='memcnn invertible PyTorch', cmdclass={ 'verify': VerifyVersionCommand, } ) ================================================ FILE: tox.ini ================================================ [tox] envlist={py38}-torch{10,11,14,17,latest},release,docs skipsdist=True [testenv] passenv=LC_ALL, LANG, HOME commands=pytest --cov=memcnn --cov-report=html --cov-report=xml --junitxml=test-reports/junit.xml deps= pip==19.1.1 numpy SimpleITK tqdm pytest pytest-cov torch14: torch==1.4.0 torch14: torchvision==0.5.0 torch17: torch==1.7.0 torch17: torchvision==0.8.1 torchlatest: torch torchlatest: torchvision [testenv:release] deps= bumpversion commands=bumpversion --dry-run minor # generate the sphinx doc [testenv:docs] basepython=python changedir=docs deps=-rdocsRequirements.txt commands=sphinx-build -b linkcheck -b html -d {envtmpdir}/doctrees . {envtmpdir}/html