Repository: inferno-pytorch/inferno Branch: master Commit: 789c6d00b34c Files: 179 Total size: 610.4 KB Directory structure: gitextract_cy4q7gy0/ ├── .editorconfig ├── .github/ │ └── ISSUE_TEMPLATE.md ├── .gitignore ├── .travis.yml ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── add2path.sh ├── build_docs.sh ├── conda-recipe/ │ ├── build.sh │ └── meta.yaml ├── docs/ │ ├── .gitignore │ ├── Makefile │ ├── _templates/ │ │ ├── layout.html │ │ └── template_module.rst │ ├── authors.rst │ ├── conf.py │ ├── contributing.rst │ ├── environment.yml │ ├── examples.rst │ ├── history.rst │ ├── index.rst │ ├── inferno-apidoc/ │ │ ├── inferno.extensions.containers.rst │ │ ├── inferno.extensions.criteria.rst │ │ ├── inferno.extensions.initializers.rst │ │ ├── inferno.extensions.layers.rst │ │ ├── inferno.extensions.metrics.rst │ │ ├── inferno.extensions.optimizers.rst │ │ ├── inferno.extensions.rst │ │ ├── inferno.io.box.rst │ │ ├── inferno.io.core.rst │ │ ├── inferno.io.rst │ │ ├── inferno.io.transform.rst │ │ ├── inferno.io.volumetric.rst │ │ ├── inferno.rst │ │ ├── inferno.trainers.callbacks.logging.rst │ │ ├── inferno.trainers.callbacks.rst │ │ ├── inferno.trainers.rst │ │ ├── inferno.utils.rst │ │ └── modules.rst │ ├── installation.rst │ ├── make.bat │ ├── readme.rst │ ├── refs.bib │ ├── usage.rst │ └── zbibliography.rst ├── examples/ │ ├── README.txt │ ├── plot_cheap_unet.py │ ├── plot_train_side_loss_unet.py │ ├── plot_unet_tutorial.py │ ├── regularized_mnist.py │ └── trainer.py ├── inferno/ │ ├── __init__.py │ ├── extensions/ │ │ ├── __init__.py │ │ ├── containers/ │ │ │ ├── __init__.py │ │ │ ├── graph.py │ │ │ └── sequential.py │ │ ├── criteria/ │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── elementwise_measures.py │ │ │ ├── regularized.py │ │ │ └── set_similarity_measures.py │ │ ├── initializers/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── presets.py │ │ ├── layers/ │ │ │ ├── __init__.py │ │ │ ├── 
activations.py │ │ │ ├── convolutional.py │ │ │ ├── convolutional_blocks.py │ │ │ ├── device.py │ │ │ ├── identity.py │ │ │ ├── normalization.py │ │ │ ├── reshape.py │ │ │ └── sampling.py │ │ ├── metrics/ │ │ │ ├── __init__.py │ │ │ ├── arand.py │ │ │ ├── base.py │ │ │ ├── categorical.py │ │ │ ├── cremi_score.py │ │ │ └── voi.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── res_unet.py │ │ │ └── unet.py │ │ └── optimizers/ │ │ ├── __init__.py │ │ ├── adam.py │ │ ├── annealed_adam.py │ │ └── ranger.py │ ├── inferno.py │ ├── io/ │ │ ├── __init__.py │ │ ├── box/ │ │ │ ├── __init__.py │ │ │ ├── binary_blobs.py │ │ │ ├── camvid.py │ │ │ ├── cifar.py │ │ │ └── cityscapes.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── concatenate.py │ │ │ ├── data_utils.py │ │ │ └── zip.py │ │ ├── transform/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── generic.py │ │ │ ├── image.py │ │ │ └── volume.py │ │ └── volumetric/ │ │ ├── __init__.py │ │ ├── lazy_volume_loader.py │ │ ├── volume.py │ │ └── volumetric_utils.py │ ├── trainers/ │ │ ├── __init__.py │ │ ├── basic.py │ │ └── callbacks/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── console.py │ │ ├── essentials.py │ │ ├── gradients.py │ │ ├── logging/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── tensorboard.py │ │ ├── scheduling.py │ │ ├── tqdm.py │ │ └── tqdmstub.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── exceptions.py │ │ ├── io_utils.py │ │ ├── math_utils.py │ │ ├── model_utils.py │ │ ├── partial_cls.py │ │ ├── python_utils.py │ │ ├── test_utils.py │ │ ├── torch_utils.py │ │ └── train_utils.py │ └── version.py ├── readthedocs.yml ├── requirements.txt ├── requirements_dev.txt ├── setup.py └── tests/ ├── __init__.py ├── test_extensions/ │ ├── __init__.py │ ├── test_containers/ │ │ └── test_graph.py │ ├── test_criteria/ │ │ ├── test_core.py │ │ ├── test_elementwise_measures.py │ │ └── test_set_similarity_measures.py │ ├── test_layers/ │ │ ├── deprecated/ │ │ │ └── building_blocks.py │ │ ├── 
test_activations.py │ │ ├── test_convolutional.py │ │ ├── test_device.py │ │ └── test_reshape.py │ ├── test_metrics/ │ │ └── categorical.py │ └── test_models/ │ ├── __init__.py │ ├── test_res_unet.py │ └── test_unet.py ├── test_inferno.py ├── test_io/ │ ├── __init__.py │ ├── test_box/ │ │ ├── __init__.py │ │ ├── test_camvid.py │ │ └── test_cityscapes.py │ ├── test_core/ │ │ ├── __init__.py │ │ ├── test_concatenate.py │ │ └── test_zip.py │ └── test_volumetric/ │ ├── __init__.py │ ├── test_lazy_volume_loader.py │ └── test_volume_loader.py ├── test_training/ │ ├── __init__.py │ ├── test_basic.py │ └── test_callbacks/ │ ├── __init__.py │ ├── test_base.py │ ├── test_essentials.py │ ├── test_logging/ │ │ ├── __init__.py │ │ ├── test_base.py │ │ └── test_tensorboard.py │ └── test_scheduling.py └── test_utils/ ├── __init__.py ├── test_model_utils.py ├── test_partial_cls.py └── test_train_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .editorconfig ================================================ # http://editorconfig.org root = true [*] indent_style = space indent_size = 4 trim_trailing_whitespace = true insert_final_newline = true charset = utf-8 end_of_line = lf [*.bat] indent_style = tab end_of_line = crlf [LICENSE] insert_final_newline = false [Makefile] indent_style = tab ================================================ FILE: .github/ISSUE_TEMPLATE.md ================================================ * inferno version: * Python version: * Operating System: ### Description Describe what you were trying to get done. Tell us what happened, what went wrong, and what you expected to happen. ### What I Did ``` Paste the command(s) you ran and the output. If there was a crash, please include the traceback here. 
``` ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log # Sphinx documentation docs/_build/ # PyBuilder target/ # pyenv python configuration file .python-version ================================================ FILE: .travis.yml ================================================ language: python dist: xenial python: - 3.7 env: - PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch install: - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; - bash miniconda.sh -b -p $HOME/miniconda - export PATH="$HOME/miniconda/bin:$PATH" - conda config --set always_yes yes --set changeps1 no - conda update -q conda - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION - source activate test-environment - conda install -c conda-forge networkx h5py scikit-image pyyaml dill tensorboardx - conda install -c pytorch $PYTORCH_CONDA - conda install -c $TORCHVISION_CHANNEL $TORCHVISION_CONDA deploy: provider: pypi user: nasimrahaman password: secure: !!binary | bWwzZitLcEpibHBaUWNhUVA4UUlGa2JsZDQxVkx3eFlkY1FiYWJqYkFvWm5pdDErRzlKRXZFM0hR ZE15V0tIWm5JQlJRSGlveXdYNjAzQVc1UFV3ZjNBOG0zc21vK3RaZjVSYnM5aE5ySE93ajBXc1N4 akNHNGhOSnF6UnBDY2kwakxPeWhxaEwxQkR0empSaFdJbWVlOE81RDVPY2pSdGw1TDQ3QjhwVGor 
TVREdlpSYTVFd2xNNXdadTJYWFVXL3ZQY0VLZE9xckFoVk5PSHpkTTh5MGM1S1lHaS9nNThVK2JO OVp5RkFROVpuOEY3YmxPdzBQZnAvL202ZUkxamlKSmxhaE13UU4zV2tJRWRpNklVSTE0RUp1ck5s Q28xL2kzNER0dGVkZzI0eVhULzcxRFl5Y0pZQWMrcWtoa1VVVUo4NEZKV3JjUjNqTnF5bVI3Ykty cFJrR3JydjV0dUpGUnBhc2NIdEdKVUswMkdJWEJUc3JJWGg4bS9oRGtMaVJaMExBeitJQWR4b2tF MzB0OWppZ0x5VXFSMmxnVmNvZERzRWZMRnJEMTBHeTJVS2FueVhlYmpsck9qK3V5S1dtZm5UTXg4 bGNzN09HWEZiUmo2K0ZuYTg5a00xN3poSXhzc3pSMnRGSVJwamV4a0gzZUpyZlpYY1daTFZ3QnV0 clUwZW10VEsxeGFmOGFjNTd3Wll1R3JXNEZJT1h2bmxoeS9pV0FMVlE4YnVFZFFjQnJ5YWFiRjUy RkZvZk1SUnp3aDFhZ3Q3cUxVa0FIbXVuZ1NYQWZxMUlOTkVNYXRTcFVJUURJM3huWmNPeTNhSWFP YkVpSlFHY1lrWlhXZ1Z2cVdvcktPOW53a29Hem5BSm1HRVZHYU11dDYwaGg2SGU1MVJPTll3WHc9 on: all_branches: false tags: true script: - source activate test-environment - python setup.py install - python -m unittest discover -s tests -v ================================================ FILE: AUTHORS.rst ================================================ ======= Credits ======= Development Lead ---------------- * `Nasim Rahaman `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , Contributors ------------ In no particular order, * `Steffen Wolf `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , * `Maurice Weiler `_ @ `Amsterdam Machine Learning Lab `_ , `University of Amsterdam `_ , * `Constantin Pape `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , * `Sven Peter `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , * `Manuel Haussmann `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , * `Thorsten Beier `_ @ `Image Analysis and Learning Lab `_ , `Heidelberg Collaboratory for Image Processing `_ , * `Benjamin Striner `_ @ `Machine Learning Department `_ , `Carnegie Mellon University `_ , ================================================ FILE: CONTRIBUTING.rst 
================================================ .. highlight:: shell ============ Contributing ============ Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. You can contribute in many ways: Types of Contributions ---------------------- Report Bugs ~~~~~~~~~~~ Report bugs at https://github.com/nasimrahaman/inferno/issues. If you are reporting a bug, please include: * Your operating system name and version. * Any details about your local setup that might be helpful in troubleshooting. * Detailed steps to reproduce the bug. Fix Bugs ~~~~~~~~ Look through the GitHub issues for bugs. Anything tagged with "bug" and "help wanted" is open to whoever wants to implement it. Implement Features ~~~~~~~~~~~~~~~~~~ Look through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it. Write Documentation ~~~~~~~~~~~~~~~~~~~ inferno could always use more documentation, whether as part of the official inferno docs, in docstrings, or even on the web in blog posts, articles, and such. Submit Feedback ~~~~~~~~~~~~~~~ The best way to send feedback is to file an issue at https://github.com/nasimrahaman/inferno/issues. If you are proposing a feature: * Explain in detail how it would work. * Keep the scope as narrow as possible, to make it easier to implement. * Remember that this is a volunteer-driven project, and that contributions are welcome :) Get Started! ------------ Ready to contribute? Here's how to set up `inferno` for local development. 1. Fork the `inferno` repo on GitHub. 2. Clone your fork locally:: $ git clone git@github.com:your_name_here/inferno.git 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: $ mkvirtualenv inferno $ cd inferno/ $ python setup.py develop 4. 
Create a branch for local development:: $ git checkout -b name-of-your-bugfix-or-feature Now you can make your changes locally. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 inferno tests $ python setup.py test or py.test $ tox To get flake8 and tox, just pip install them into your virtualenv. 6. Commit your changes and push your branch to GitHub:: $ git add . $ git commit -m "Your detailed description of your changes." $ git push origin name-of-your-bugfix-or-feature 7. Submit a pull request through the GitHub website. Pull Request Guidelines ----------------------- Before you submit a pull request, check that it meets these guidelines: 1. The pull request should include tests. 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. 3. The pull request should work for Python 3.5 and 3.6. Check https://travis-ci.org/nasimrahaman/inferno/pull_requests and make sure that the tests pass for all supported Python versions. Tips ---- To run a subset of tests:: $ python -m unittest tests.test_inferno Sphinx Apidoc -------------- before building the documentation one needs to generate the auto-generated sphinxs api documentation. These files need to be in the github repository. .. code:: bash cd docs sphinx-apidoc -o inferno-apidoc ../inferno .. warning:: Do not make changes to `inferno/docs/inferno-apidoc` This folder is auto-generated by the above mentioned command. The following combines all the commands necessary to build the html documentation: .. 
code:: bash ./build_docs.sh ================================================ FILE: HISTORY.rst ================================================ ======= History ======= 0.1.0 (2017-08-24) ------------------ * First early release on PyPI 0.1.1 (2017-08-24) ------------------ * Version Increment 0.1.2 (2017-08-24) ------------------ * Version Increment 0.1.3 (2017-08-24) ------------------ * Updated Documentation 0.1.4 (2017-08-24) ------------------ * travis auto-deployment on pypi 0.1.5 (2017-08-24) ------------------ * travis changes to run unittest 0.1.6 (2017-08-24) ------------------ * travis missing packages for unittesting * fixed inconsistent version numbers 0.1.7 (2017-08-25) ------------------ * setup.py critical bugfix in install procedure CURRENT CHANGES ----------------- * Flexible Unet ================================================ FILE: LICENSE ================================================ Apache Software License 2.0 Copyright (c) 2017, Inferno Developers Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
================================================ FILE: MANIFEST.in ================================================ include AUTHORS.rst include CONTRIBUTING.rst include HISTORY.rst include LICENSE include README.rst recursive-include tests * recursive-exclude * __pycache__ recursive-exclude * *.py[co] recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif ================================================ FILE: Makefile ================================================ .PHONY: clean clean-test clean-pyc clean-build docs help .DEFAULT_GOAL := help define BROWSER_PYSCRIPT import os, webbrowser, sys try: from urllib import pathname2url except: from urllib.request import pathname2url webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) endef export BROWSER_PYSCRIPT define PRINT_HELP_PYSCRIPT import re, sys for line in sys.stdin: match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) if match: target, help = match.groups() print("%-20s %s" % (target, help)) endef export PRINT_HELP_PYSCRIPT BROWSER := python -c "$$BROWSER_PYSCRIPT" help: @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts clean-build: ## remove build artifacts rm -fr build/ rm -fr dist/ rm -fr .eggs/ find . -name '*.egg-info' -exec rm -fr {} + find . -name '*.egg' -exec rm -f {} + clean-pyc: ## remove Python file artifacts find . -name '*.pyc' -exec rm -f {} + find . -name '*.pyo' -exec rm -f {} + find . -name '*~' -exec rm -f {} + find . 
-name '__pycache__' -exec rm -fr {} + clean-test: ## remove test and coverage artifacts rm -fr .tox/ rm -f .coverage rm -fr htmlcov/ lint: ## check style with flake8 flake8 inferno tests test: ## run tests quickly with the default Python python setup.py test test-all: ## run tests on every Python version with tox tox coverage: ## check code coverage quickly with the default Python coverage run --source inferno setup.py test coverage report -m coverage html $(BROWSER) htmlcov/index.html docs: ## generate Sphinx HTML documentation, including API docs rm -f docs/inferno.rst rm -f docs/modules.rst sphinx-apidoc -o docs/ inferno $(MAKE) -C docs clean $(MAKE) -C docs html $(BROWSER) docs/_build/html/index.html servedocs: docs ## compile the docs watching for changes watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . release: clean ## package and upload a release python setup.py sdist upload python setup.py bdist_wheel upload dist: clean ## builds source and wheel package python setup.py sdist python setup.py bdist_wheel ls -l dist install: clean ## install the package to the active Python's site-packages python setup.py install ================================================ FILE: README.rst ================================================ ======= Inferno ======= .. image:: https://anaconda.org/conda-forge/inferno/badges/version.svg :target: https://anaconda.org/conda-forge/inferno .. image:: https://travis-ci.org/inferno-pytorch/inferno.svg?branch=master :target: https://travis-ci.org/inferno-pytorch/inferno .. TODO new docs shield goes here, see https://github.com/inferno-pytorch/inferno/issues/139 .. image:: https://readthedocs.org/projects/inferno-pytorch/badge/?version=latest :target: http://inferno-pytorch.readthedocs.io/en/latest/?badge=latest :alt: Documentation Status .. image:: http://svgshare.com/i/2j7.svg Inferno is a little library providing utilities and convenience functions/classes around `PyTorch `_. 
It's a work-in-progress, but the releases from v0.4 on should be fairly stable! * Free software: Apache Software License 2.0 * Documentation: http://inferno-pytorch.readthedocs.io (Work in Progress). Features -------- Current features include: * a basic `Trainer class `_ to encapsulate the training boilerplate (iteration/epoch loops, validation and checkpoint creation), * a `graph API `_ for building models with complex architectures, powered by `networkx `_. * `easy data-parallelism `_ over multiple GPUs, * `a submodule `_ for `torch.nn.Module`-level parameter initialization, * `a submodule `_ for data preprocessing / transforms, * `support `_ for `Tensorboard `_ (best with at least `tensorflow-cpu `_ installed) * `a callback API `_ to enable flexible interaction with the trainer, * `various utility layers `_ with more underway, * `a submodule `_ for volumetric datasets, and more! .. code:: python import torch.nn as nn from inferno.io.box.cifar import get_cifar10_loaders from inferno.trainers.basic import Trainer from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger from inferno.extensions.layers.convolutional import ConvELU2D from inferno.extensions.layers.reshape import Flatten # Fill these in: LOG_DIRECTORY = '...' SAVE_DIRECTORY = '...' DATASET_DIRECTORY = '...' 
DOWNLOAD_CIFAR = True USE_CUDA = True # Build torch model model = nn.Sequential( ConvELU2D(in_channels=3, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), Flatten(), nn.Linear(in_features=(256 * 4 * 4), out_features=10), nn.LogSoftmax(dim=1) ) # Load loaders train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY, download=DOWNLOAD_CIFAR) # Build trainer trainer = Trainer(model) \ .build_criterion('NLLLoss') \ .build_metric('CategoricalError') \ .build_optimizer('Adam') \ .validate_every((2, 'epochs')) \ .save_every((5, 'epochs')) \ .save_to_directory(SAVE_DIRECTORY) \ .set_max_num_epochs(10) \ .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), log_images_every='never'), log_directory=LOG_DIRECTORY) # Bind loaders trainer \ .bind_loader('train', train_loader) \ .bind_loader('validate', validate_loader) if USE_CUDA: trainer.cuda() # Go! trainer.fit() To visualize the training progress, navigate to `LOG_DIRECTORY` and fire up tensorboard with .. code:: bash $ tensorboard --logdir=${PWD} --port=6007 and navigate to `localhost:6007` with your browser. Installation ------------------------ Conda packages for python >= 3.6 for all distributions are available on conda-forge: .. code:: bash $ conda install -c pytorch -c conda-forge inferno Future Features: ------------------------ Planned features include: * a class to encapsulate Hogwild! training over multiple GPUs, * minimal shape inference with a dry-run, * proper packaging and documentation, * cutting-edge fresh-off-the-press implementations of what the future has in store. :) Credits --------- All contributors are listed here_. ..
_here: https://inferno-pytorch.github.io/inferno/html/authors.html This package was partially generated with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template + lots of work by Thorsten. .. _Cookiecutter: https://github.com/audreyr/cookiecutter .. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage ================================================ FILE: add2path.sh ================================================ #!/usr/bin/env bash # Run this script from within the directory. export PYTHONPATH=${PYTHONPATH}:${PWD} ================================================ FILE: build_docs.sh ================================================ #!/bin/bash cd docs rm -r -f inferno-apidoc sphinx-apidoc -o inferno-apidoc ../inferno make html cd .. ================================================ FILE: conda-recipe/build.sh ================================================ PY_VER=$(python -c "import sys; print('{}.{}'.format(*sys.version_info[:2]))") # Install python modules mkdir -p ${PREFIX}/inferno cp -r inferno/* ${PREFIX}/inferno echo "${PREFIX}" > ${PREFIX}/lib/python${PY_VER}/site-packages/inferno.pth python -m compileall ${PREFIX}/inferno ================================================ FILE: conda-recipe/meta.yaml ================================================ package: name: inferno {% set tagged_version = GIT_DESCRIBE_TAG|replace("v","")|replace("-", ".") %} # If we're using a non-tagged revision, append '.postN' to the version {% if GIT_DESCRIBE_NUMBER|int != 0 %} {% set tagged_version = tagged_version + '.post' + GIT_DESCRIBE_NUMBER %} {% endif %} version: {{tagged_version}} source: path: .. 
build: number: 1 string: py_{{PKG_BUILDNUM}}_g{{GIT_FULL_HASH[:7]}} requirements: build: - python {{PY_VER}}* run: - python {{PY_VER}}* - pytorch - torchvision - pyyaml - scipy - scikit-image - scikit-learn - h5py - dill - networkx 1.11 - tensorboardx - sphinx_rtd_theme test: imports: - inferno about: license: Apache License 2.0 summary: A utility library around PyTorch ================================================ FILE: docs/.gitignore ================================================ /inferno.rst /inferno.*.rst /modules.rst ================================================ FILE: docs/Makefile ================================================ # Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." 
json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/inferno.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/inferno.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/inferno" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/inferno" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." 
man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." ================================================ FILE: docs/_templates/layout.html ================================================ {# layout.html #} {# Import the theme's layout. 
#} {% extends "!layout.html" %} {% set css_files = css_files + ['_static/pygments.css'] %} ================================================ FILE: docs/_templates/template_module.rst ================================================ {{ fullname }} {{ underline }} .. automodule:: {{ fullname }} {% block functions %} {% if functions %} Functions ================== {% for item in functions %} .. autofunction:: {{ item }} .. include:: backreferences/{{fullname}}.{{item}}.examples .. raw:: html
{%- endfor %} {% endif %} {% endblock %} {% block classes %} {% if classes %} Classes ------- {% for item in classes %} .. autoclass:: {{ item }} :members: .. include:: backreferences/{{fullname}}.{{item}}.examples .. raw:: html
{%- endfor %} {% endif %} {% endblock %} {% block exceptions %} {% if exceptions %} Exceptions ---------- .. autosummary:: {% for item in exceptions %} {{ item }} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/authors.rst ================================================ .. include:: ../AUTHORS.rst ================================================ FILE: docs/conf.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- # # inferno documentation build configuration file, created by # sphinx-quickstart on Tue Jul 9 22:26:36 2013. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import matplotlib matplotlib.use('Agg') import sphinx_gallery import sys from unittest.mock import MagicMock class Mock(MagicMock): @classmethod def __getattr__(cls, name): return MagicMock() # MOCK_MODULES = ['pygtk', # 'hdf5', # 'skimage', # 'argparse', # 'pandas', # 'torch', # 'torch.nn', 'torch.nn.init', 'torch.nn.functional', # 'torch.nn.parallel', 'torch.nn.parallel.data_parallel', # 'torch.multiprocessing', 'torch.autograd', # 'torch.utils', 'torch.utils.data', # 'torch.optim', 'torch.sparse', 'torch.cuda'] # sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) import os # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory is # relative to the documentation root, use os.path.abspath to make it # absolute, like shown here. #sys.path.insert(0, os.path.abspath('.')) # Get the project root dir, which is the parent dir of this cwd = os.getcwd() project_root = os.path.dirname(cwd) # Insert the project root dir as the first element in the PYTHONPATH. 
# This lets us ensure that the source package is imported, and that its # version is used. sys.path.insert(0, project_root) import inferno import inferno.extensions import inferno.extensions.layers from inferno.extensions.layers import * from inferno.extensions.layers.reshape import * # -- General configuration --------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', 'sphinx.ext.doctest', 'sphinx.ext.todo', 'sphinx.ext.ifconfig', 'sphinx.ext.mathjax', 'sphinx.ext.graphviz', 'sphinx_gallery.gen_gallery', 'sphinxcontrib.bibtex', 'sphinx.ext.napoleon', 'sphinxcontrib.inlinesyntaxhighlight' ] sphinx_gallery_conf = { # path to your examples scripts 'examples_dirs' : '../examples', # path where to save gallery generated examples 'gallery_dirs' : 'auto_examples', 'backreferences_dir' : 'gen_modules/backreferences', 'scan_used_functions': True, 'doc_module' : ('inferno','inferno.extensions','inferno.extensions.layers','inferno.extensions.layers.convolutional'), 'docs_resolv': True, 'parallel_read_safe': True, 'reference_url': { # The module you locally document uses a None 'inferno': None, # External python modules use their documentation websites #'matplotlib': 'http://matplotlib.org', 'numpy': 'http://docs.scipy.org/doc/numpy-1.13.0'} } # Napoleon settings napoleon_google_docstring = True napoleon_numpy_docstring = True napoleon_include_init_with_doc = False napoleon_include_private_with_doc = False napoleon_include_special_with_doc = True napoleon_use_admonition_for_examples = False napoleon_use_admonition_for_notes = False napoleon_use_admonition_for_references = False napoleon_use_ivar = False napoleon_use_param = True napoleon_use_rtype = True # Add any paths 
that contain templates here, relative to this directory. templates_path = ['_templates'] # generate autosummary even if no references autosummary_generate = True # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = u'inferno' copyright = u"2018, f" # The version info for the project you're documenting, acts as replacement # for |version| and |release|, also used in various other places throughout # the built documents. # # The short X.Y version. version = inferno.__version__ # The full version, including alpha/beta/rc tags. release = inferno.__version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. #language = None # There are two options for replacing |today|: either, you set today to # some non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build'] # The reST default role (used for this markup: `text`) to use for all # documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built # documents. 
#keep_warnings = False # -- Options for HTML output ------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the # documentation. #html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as # html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the # top of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon # of the docs. This file should be a Windows icon file (.ico) being # 16x16 or 32x32 pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) # here, relative to this directory. They are copied after the builtin # static files, so a file named "default.css" will overwrite the builtin # "default.css". html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page # bottom, using the given strftime format. #html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names # to template names. #html_additional_pages = {} # If false, no module index is generated. #html_domain_indices = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. 
#html_split_index = False # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. # Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. # Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages # will contain a tag referring to it. The value of this option # must be the base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'infernodoc' # -- Options for LaTeX output ------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ ('index', 'inferno.tex', u'inferno Documentation', u'Inferno Team', 'manual'), ] # The name of an image file (relative to this directory) to place at # the top of the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings # are parts, not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output ------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
man_pages = [ ('index', 'inferno', u'inferno Documentation', [u'Inferno Team'], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ---------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ('index', 'inferno', u'inferno Documentation', u'Inferno Team', 'inferno', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False ================================================ FILE: docs/contributing.rst ================================================ .. include:: ../CONTRIBUTING.rst ================================================ FILE: docs/environment.yml ================================================ name: inferno_docs channels: - soumith - anaconda dependencies: - python==3.5 - pytorch>=0.1.12 - torchvision - scikit-image - pip: - scipy>=0.13.0 - h5py - scikit-image - pyyaml - dill - sphinx-gallery - sphinxcontrib-napoleon - sphinxcontrib-bibtex - sphinxcontrib-inlinesyntaxhighlight ================================================ FILE: docs/examples.rst ================================================ .. _inferno_examples_gallery: Inferno Examples Gallery ============================ .. toctree:: :maxdepth: 5 ../auto_examples/index ================================================ FILE: docs/history.rst ================================================ .. 
include:: ../HISTORY.rst ================================================ FILE: docs/index.rst ================================================ Welcome to inferno's documentation! ====================================== Contents: .. toctree:: :maxdepth: 1 readme installation usage examples contributing inferno-apidoc/modules authors history zbibliography .. automodule:: inferno Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/inferno-apidoc/inferno.extensions.containers.rst ================================================ inferno.extensions.containers package ===================================== Submodules ---------- inferno.extensions.containers.graph module ------------------------------------------ .. automodule:: inferno.extensions.containers.graph :members: :undoc-members: :show-inheritance: inferno.extensions.containers.sequential module ----------------------------------------------- .. automodule:: inferno.extensions.containers.sequential :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.containers :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.criteria.rst ================================================ inferno.extensions.criteria package =================================== Submodules ---------- inferno.extensions.criteria.core module --------------------------------------- .. automodule:: inferno.extensions.criteria.core :members: :undoc-members: :show-inheritance: inferno.extensions.criteria.elementwise\_measures module -------------------------------------------------------- .. automodule:: inferno.extensions.criteria.elementwise_measures :members: :undoc-members: :show-inheritance: inferno.extensions.criteria.regularized module ---------------------------------------------- .. 
automodule:: inferno.extensions.criteria.regularized :members: :undoc-members: :show-inheritance: inferno.extensions.criteria.set\_similarity\_measures module ------------------------------------------------------------ .. automodule:: inferno.extensions.criteria.set_similarity_measures :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.criteria :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.initializers.rst ================================================ inferno.extensions.initializers package ======================================= Submodules ---------- inferno.extensions.initializers.base module ------------------------------------------- .. automodule:: inferno.extensions.initializers.base :members: :undoc-members: :show-inheritance: inferno.extensions.initializers.presets module ---------------------------------------------- .. automodule:: inferno.extensions.initializers.presets :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.initializers :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.layers.rst ================================================ inferno.extensions.layers package ================================= Submodules ---------- inferno.extensions.layers.activations module -------------------------------------------- .. automodule:: inferno.extensions.layers.activations :members: :undoc-members: :show-inheritance: inferno.extensions.layers.building\_blocks module ------------------------------------------------- .. automodule:: inferno.extensions.layers.building_blocks :members: :undoc-members: :show-inheritance: inferno.extensions.layers.convolutional module ---------------------------------------------- .. 
automodule:: inferno.extensions.layers.convolutional :members: :undoc-members: :show-inheritance: inferno.extensions.layers.device module --------------------------------------- .. automodule:: inferno.extensions.layers.device :members: :undoc-members: :show-inheritance: inferno.extensions.layers.identity module ----------------------------------------- .. automodule:: inferno.extensions.layers.identity :members: :undoc-members: :show-inheritance: inferno.extensions.layers.prefab module --------------------------------------- .. automodule:: inferno.extensions.layers.prefab :members: :undoc-members: :show-inheritance: inferno.extensions.layers.res\_unet module ------------------------------------------ .. automodule:: inferno.extensions.layers.res_unet :members: :undoc-members: :show-inheritance: inferno.extensions.layers.reshape module ---------------------------------------- .. automodule:: inferno.extensions.layers.reshape :members: :undoc-members: :show-inheritance: inferno.extensions.layers.sampling module ----------------------------------------- .. automodule:: inferno.extensions.layers.sampling :members: :undoc-members: :show-inheritance: inferno.extensions.layers.unet\_base module ------------------------------------------- .. automodule:: inferno.extensions.layers.unet_base :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.layers :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.metrics.rst ================================================ inferno.extensions.metrics package ================================== Submodules ---------- inferno.extensions.metrics.arand module --------------------------------------- .. automodule:: inferno.extensions.metrics.arand :members: :undoc-members: :show-inheritance: inferno.extensions.metrics.base module -------------------------------------- .. 
automodule:: inferno.extensions.metrics.base :members: :undoc-members: :show-inheritance: inferno.extensions.metrics.categorical module --------------------------------------------- .. automodule:: inferno.extensions.metrics.categorical :members: :undoc-members: :show-inheritance: inferno.extensions.metrics.cremi\_score module ---------------------------------------------- .. automodule:: inferno.extensions.metrics.cremi_score :members: :undoc-members: :show-inheritance: inferno.extensions.metrics.voi module ------------------------------------- .. automodule:: inferno.extensions.metrics.voi :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.metrics :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.optimizers.rst ================================================ inferno.extensions.optimizers package ===================================== Submodules ---------- inferno.extensions.optimizers.adam module ----------------------------------------- .. automodule:: inferno.extensions.optimizers.adam :members: :undoc-members: :show-inheritance: inferno.extensions.optimizers.annealed\_adam module --------------------------------------------------- .. automodule:: inferno.extensions.optimizers.annealed_adam :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.extensions.optimizers :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.extensions.rst ================================================ inferno.extensions package ========================== Subpackages ----------- .. toctree:: inferno.extensions.containers inferno.extensions.criteria inferno.extensions.initializers inferno.extensions.layers inferno.extensions.metrics inferno.extensions.optimizers Module contents --------------- .. 
automodule:: inferno.extensions :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.io.box.rst ================================================ inferno.io.box package ====================== Submodules ---------- inferno.io.box.binary\_blobs module ----------------------------------- .. automodule:: inferno.io.box.binary_blobs :members: :undoc-members: :show-inheritance: inferno.io.box.camvid module ---------------------------- .. automodule:: inferno.io.box.camvid :members: :undoc-members: :show-inheritance: inferno.io.box.cifar module --------------------------- .. automodule:: inferno.io.box.cifar :members: :undoc-members: :show-inheritance: inferno.io.box.cityscapes module -------------------------------- .. automodule:: inferno.io.box.cityscapes :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.io.box :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.io.core.rst ================================================ inferno.io.core package ======================= Submodules ---------- inferno.io.core.base module --------------------------- .. automodule:: inferno.io.core.base :members: :undoc-members: :show-inheritance: inferno.io.core.concatenate module ---------------------------------- .. automodule:: inferno.io.core.concatenate :members: :undoc-members: :show-inheritance: inferno.io.core.data\_utils module ---------------------------------- .. automodule:: inferno.io.core.data_utils :members: :undoc-members: :show-inheritance: inferno.io.core.zip module -------------------------- .. automodule:: inferno.io.core.zip :members: :undoc-members: :show-inheritance: Module contents --------------- .. 
automodule:: inferno.io.core :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.io.rst ================================================ inferno.io package ================== Subpackages ----------- .. toctree:: inferno.io.box inferno.io.core inferno.io.transform inferno.io.volumetric Module contents --------------- .. automodule:: inferno.io :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.io.transform.rst ================================================ inferno.io.transform package ============================ Submodules ---------- inferno.io.transform.base module -------------------------------- .. automodule:: inferno.io.transform.base :members: :undoc-members: :show-inheritance: inferno.io.transform.generic module ----------------------------------- .. automodule:: inferno.io.transform.generic :members: :undoc-members: :show-inheritance: inferno.io.transform.image module --------------------------------- .. automodule:: inferno.io.transform.image :members: :undoc-members: :show-inheritance: inferno.io.transform.volume module ---------------------------------- .. automodule:: inferno.io.transform.volume :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.io.transform :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.io.volumetric.rst ================================================ inferno.io.volumetric package ============================= Submodules ---------- inferno.io.volumetric.lazy\_volume\_loader module ------------------------------------------------- .. automodule:: inferno.io.volumetric.lazy_volume_loader :members: :undoc-members: :show-inheritance: inferno.io.volumetric.volume module ----------------------------------- .. 
automodule:: inferno.io.volumetric.volume :members: :undoc-members: :show-inheritance: inferno.io.volumetric.volumetric\_utils module ---------------------------------------------- .. automodule:: inferno.io.volumetric.volumetric_utils :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.io.volumetric :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.rst ================================================ inferno package =============== Subpackages ----------- .. toctree:: inferno.extensions inferno.io inferno.trainers inferno.utils Submodules ---------- inferno.inferno module ---------------------- .. automodule:: inferno.inferno :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst ================================================ inferno.trainers.callbacks.logging package ========================================== Submodules ---------- inferno.trainers.callbacks.logging.base module ---------------------------------------------- .. automodule:: inferno.trainers.callbacks.logging.base :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.logging.tensorboard module ----------------------------------------------------- .. automodule:: inferno.trainers.callbacks.logging.tensorboard :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.trainers.callbacks.logging :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.trainers.callbacks.rst ================================================ inferno.trainers.callbacks package ================================== Subpackages ----------- .. 
toctree:: inferno.trainers.callbacks.logging Submodules ---------- inferno.trainers.callbacks.base module -------------------------------------- .. automodule:: inferno.trainers.callbacks.base :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.console module ----------------------------------------- .. automodule:: inferno.trainers.callbacks.console :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.essentials module -------------------------------------------- .. automodule:: inferno.trainers.callbacks.essentials :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.scheduling module -------------------------------------------- .. automodule:: inferno.trainers.callbacks.scheduling :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.tqdm module -------------------------------------- .. automodule:: inferno.trainers.callbacks.tqdm :members: :undoc-members: :show-inheritance: inferno.trainers.callbacks.tqdmstub module ------------------------------------------ .. automodule:: inferno.trainers.callbacks.tqdmstub :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.trainers.callbacks :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.trainers.rst ================================================ inferno.trainers package ======================== Subpackages ----------- .. toctree:: inferno.trainers.callbacks Submodules ---------- inferno.trainers.basic module ----------------------------- .. automodule:: inferno.trainers.basic :members: :undoc-members: :show-inheritance: Module contents --------------- .. 
automodule:: inferno.trainers :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/inferno.utils.rst ================================================ inferno.utils package ===================== Submodules ---------- inferno.utils.exceptions module ------------------------------- .. automodule:: inferno.utils.exceptions :members: :undoc-members: :show-inheritance: inferno.utils.io\_utils module ------------------------------ .. automodule:: inferno.utils.io_utils :members: :undoc-members: :show-inheritance: inferno.utils.math\_utils module -------------------------------- .. automodule:: inferno.utils.math_utils :members: :undoc-members: :show-inheritance: inferno.utils.model\_utils module --------------------------------- .. automodule:: inferno.utils.model_utils :members: :undoc-members: :show-inheritance: inferno.utils.python\_utils module ---------------------------------- .. automodule:: inferno.utils.python_utils :members: :undoc-members: :show-inheritance: inferno.utils.test\_utils module -------------------------------- .. automodule:: inferno.utils.test_utils :members: :undoc-members: :show-inheritance: inferno.utils.torch\_utils module --------------------------------- .. automodule:: inferno.utils.torch_utils :members: :undoc-members: :show-inheritance: inferno.utils.train\_utils module --------------------------------- .. automodule:: inferno.utils.train_utils :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: inferno.utils :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/inferno-apidoc/modules.rst ================================================ inferno ======= .. toctree:: :maxdepth: 4 inferno ================================================ FILE: docs/installation.rst ================================================ .. 
highlight:: shell ================================== Installation ================================== Install on Linux and OSX ------------------------ Developers ~~~~~~~~~~~~~~~~~~~~~~ First, make sure `you have Pytorch installed <http://pytorch.org/>`_. Then, clone this repository with: .. code:: python $ git clone https://github.com/nasimrahaman/inferno.git Next, install the dependencies. .. code:: python $ cd inferno $ pip install -r requirements.txt If you use python from the shell: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Finally, add *inferno* to your `PYTHONPATH` with: .. code:: python source add2path.sh If you use PyCharm: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Refer to this `QA `_ about setting up paths with Pycharm. ====================================================== Installation via PyPi / pip / setup.py(Experimental) ====================================================== You need to install pytorch via pip before installing inferno. Follow the `pytorch installation guide`_. Stable release -------------- To install inferno, run this command in your terminal: .. code-block:: console $ pip install inferno-pytorch This is the preferred method to install inferno, as it will always install the most recent stable release. If you don't have `pip`_ installed, this `Python installation guide`_ can guide you through the process. .. _pip: https://pip.pypa.io .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ .. _pytorch installation guide: http://pytorch.org/ From sources ------------------------ First, make sure `you have Pytorch installed <http://pytorch.org/>`_. The sources for inferno can be downloaded from the `Github repo`_. You can either clone the public repository: .. code-block:: console $ git clone git://github.com/nasimrahaman/inferno Or download the `tarball`_: .. code-block:: console $ curl -OL https://github.com/nasimrahaman/inferno/tarball/master Once you have a copy of the source, you can install it with: ..
code-block:: console $ python setup.py install .. _Github repo: https://github.com/nasimrahaman/inferno .. _tarball: https://github.com/nasimrahaman/inferno/tarball/master ================================================ FILE: docs/make.bat ================================================ @ECHO OFF REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set BUILDDIR=_build set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set I18NSPHINXOPTS=%SPHINXOPTS% . if NOT "%PAPER%" == "" ( set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% ) if "%1" == "" goto help if "%1" == "help" ( :help echo.Please use `make ^` where ^ is one of echo. html to make standalone HTML files echo. dirhtml to make HTML files named index.html in directories echo. singlehtml to make a single large HTML file echo. pickle to make pickle files echo. json to make JSON files echo. htmlhelp to make HTML files and a HTML help project echo. qthelp to make HTML files and a qthelp project echo. devhelp to make HTML files and a Devhelp project echo. epub to make an epub echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. text to make text files echo. man to make manual pages echo. texinfo to make Texinfo files echo. gettext to make PO message catalogs echo. changes to make an overview over all changed/added/deprecated items echo. xml to make Docutils-native XML files echo. pseudoxml to make pseudoxml-XML files for display purposes echo. linkcheck to check all external links for integrity echo. doctest to run all doctests embedded in the documentation if enabled goto end ) if "%1" == "clean" ( for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i del /q /s %BUILDDIR%\* goto end ) %SPHINXBUILD% 2> nul if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) if "%1" == "html" ( %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/html. goto end ) if "%1" == "dirhtml" ( %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. goto end ) if "%1" == "singlehtml" ( %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. goto end ) if "%1" == "pickle" ( %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the pickle files. goto end ) if "%1" == "json" ( %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the JSON files. goto end ) if "%1" == "htmlhelp" ( %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run HTML Help Workshop with the ^ .hhp project file in %BUILDDIR%/htmlhelp. goto end ) if "%1" == "qthelp" ( %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run "qcollectiongenerator" with the ^ .qhcp project file in %BUILDDIR%/qthelp, like this: echo.^> qcollectiongenerator %BUILDDIR%\qthelp\inferno.qhcp echo.To view the help file: echo.^> assistant -collectionFile %BUILDDIR%\qthelp\inferno.qhc goto end ) if "%1" == "devhelp" ( %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp if errorlevel 1 exit /b 1 echo.
echo.Build finished. goto end ) if "%1" == "epub" ( %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub if errorlevel 1 exit /b 1 echo. echo.Build finished. The epub file is in %BUILDDIR%/epub. goto end ) if "%1" == "latex" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex if errorlevel 1 exit /b 1 echo. echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdf" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdfja" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf-ja cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "text" ( %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text if errorlevel 1 exit /b 1 echo. echo.Build finished. The text files are in %BUILDDIR%/text. goto end ) if "%1" == "man" ( %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man if errorlevel 1 exit /b 1 echo. echo.Build finished. The manual pages are in %BUILDDIR%/man. goto end ) if "%1" == "texinfo" ( %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo if errorlevel 1 exit /b 1 echo. echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. goto end ) if "%1" == "gettext" ( %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale if errorlevel 1 exit /b 1 echo. echo.Build finished. The message catalogs are in %BUILDDIR%/locale. goto end ) if "%1" == "changes" ( %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes if errorlevel 1 exit /b 1 echo. echo.The overview file is in %BUILDDIR%/changes. goto end ) if "%1" == "linkcheck" ( %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck if errorlevel 1 exit /b 1 echo. echo.Link check complete; look for any errors in the above output ^ or in %BUILDDIR%/linkcheck/output.txt. 
goto end ) if "%1" == "doctest" ( %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest if errorlevel 1 exit /b 1 echo. echo.Testing of doctests in the sources finished, look at the ^ results in %BUILDDIR%/doctest/output.txt. goto end ) if "%1" == "xml" ( %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml if errorlevel 1 exit /b 1 echo. echo.Build finished. The XML files are in %BUILDDIR%/xml. goto end ) if "%1" == "pseudoxml" ( %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml if errorlevel 1 exit /b 1 echo. echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. goto end ) :end ================================================ FILE: docs/readme.rst ================================================ .. include:: ../README.rst ================================================ FILE: docs/refs.bib ================================================ @inproceedings{alush_2013_simbad, title={Break and Conquer: Efficient Correlation Clustering for Image Segmentation}, author={Alush, Amir and Goldberger, Jacob}, booktitle={2nd International Workshop on Similarity-Based Pattern Analysis and Recognition}, year={2013} } ================================================ FILE: docs/usage.rst ================================================ ===== Usage ===== Inferno is a utility library built around [PyTorch](http://pytorch.org/), designed to help you train and even build complex pytorch models. And in this tutorial, we'll see how! If you're new to PyTorch, I highly recommended you work through the [Pytorch tutorials](http://pytorch.org/tutorials/) first. Building a PyTorch Model ~~~~~~~~~~~~~~~~~~~~~~~~~~ Inferno's training machinery works with just about any valid [Pytorch module](http://pytorch.org/docs/master/nn.html#torch.nn.Module). However, to make things even easier, we also provide pre-configured layers that work out-of-the-box. Let's use them to build a convolutional neural network for Cifar-10. .. 
code:: python import torch.nn as nn from inferno.extensions.layers.convolutional import ConvELU2D from inferno.extensions.layers.reshape import Flatten `ConvELU2D` is a 2-dimensional convolutional layer with orthogonal weight initialization and [ELU](http://pytorch.org/docs/master/nn.html#torch.nn.ELU) activation. `Flatten` reshapes the 4 dimensional activation tensor to a matrix. Let's use the Sequential container to chain together a bunch of convolutional and pooling layers, followed by a linear and softmax layer. .. code:: python model = nn.Sequential( ConvELU2D(in_channels=3, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), Flatten(), nn.Linear(in_features=(256 * 4 * 4), out_features=10), nn.Softmax() ) Models this size don't win competitions anymore, but it'll do for our purpose. Data Logistics ************************** With our model built, it's time to worry about the data generators. Or is it? .. code:: python from inferno.io.box.cifar import get_cifar10_loaders train_loader, validate_loader = get_cifar10_loaders('path/to/cifar10', download=True, train_batch_size=128, test_batch_size=100) CIFAR-10 works out-of-the-`box` (pun very much intended) with all the fancy data-augmentation and normalization. Of course, it's perfectly fine if you have your own [`DataLoader`](http://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader). Preparing the Trainer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ With our model and data loaders good to go, it's finally time to build the trainer. To start, let's initialize one. .. 
code:: python from inferno.trainers.basic import Trainer trainer = Trainer(model) # Tell trainer about the data loaders trainer.bind_loader('train', train_loader).bind_loader('validate', validate_loader) Now to the things we could do with it. Setting up Checkpointing *************************************** When training a model for days, it's usually a good idea to store the current training state to disk every once in a while. To set this up, we tell `trainer` where to store these *checkpoints* and how often. .. code:: python trainer.save_to_directory('path/to/save/directory').save_every((25, 'epochs')) So we're saving once every 25 epochs. But what if an epoch takes forever, and you don't wish to wait that long? .. code:: python trainer.save_every((1000, 'iterations')) In this setting, you're saving once every 1000 iterations (= batches). But we might also want to create a checkpoint when the validation score is the best. Easy as 1, 2, .. code:: python trainer.save_at_best_validation_score() Remember that a checkpoint contains the entire training state, and not just the model. Everything is included in the checkpoint file, including optimizer, criterion, and callbacks but __not the data loaders__. Setting up Validation ************************** Let's say you wish to validate once every 2 epochs. .. code:: python trainer.validate_every((2, 'epochs')) To be able to validate, you'll need to specify a validation metric. .. code:: python trainer.build_metric('CategoricalError') Inferno looks for a metric `'CategoricalError'` in `inferno.extensions.metrics`. To specify your own metric, subclass `inferno.extensions.metrics.base.Metric` and implement the `forward` method. With that done, you could: .. code:: python trainer.build_metric(MyMetric) or .. code:: python trainer.build_metric(MyMetric, **my_metric_kwargs) A metric might be way too expensive to evaluate every training iteration without slowing down the training. 
If this is the case and you'd like to evaluate the metric every (say) 10 *training* iterations: .. code:: python trainer.evaluate_metric_every((10, 'iterations')) However, while validating, the metric is evaluated once every iteration. Setting up the Criterion and Optimizer *************************************** With that out of the way, let's set up a training criterion and an optimizer. .. code:: python # set up the criterion trainer.build_criterion('CrossEntropyLoss') The `trainer` looks for a `'CrossEntropyLoss'` in `torch.nn`, which it finds. But any of the following would have worked: .. code:: python trainer.build_criterion(nn.CrossEntropyLoss) or .. code:: python trainer.build_criterion(nn.CrossEntropyLoss()) What this means is that if you have your own loss criterion that has the same API as any of the criteria found in `torch.nn`, you should be fine by just plugging it in. The same holds for the optimizer: .. code:: python trainer.build_optimizer('Adam', weight_decay=0.0005) Like for criteria, the `trainer` looks for a `'Adam'` in `torch.optim` (among other places), and initializes it with `model`'s parameters. Any keywords you might use for `torch.optim.Adam`, you could pass them to the `build_optimizer` method. Or alternatively, you could use: .. code:: python from torch.optim import Adam trainer.build_optimizer(Adam, weight_decay=0.0005) If you implemented your own optimizer (by subclassing `torch.optim.Optimizer`), you should be able to use it instead of `Adam`. Alternatively, if you already have an optimizer *instance*, you could do: .. code:: python optimizer = MyOptimizer(model.parameters(), **optimizer_kwargs) trainer.build_optimizer(optimizer) Setting up Training Duration ******************************** You probably don't want to train forever, in which case you must specify: .. code:: python trainer.set_max_num_epochs(100) or .. 
code:: python trainer.set_max_num_iterations(10000) If you like to train indefinitely (or until you're happy with the results), use: .. code:: python trainer.set_max_num_iterations('inf') In this case, you'll need to interrupt the training manually with a `KeyboardInterrupt`. Setting up Callbacks ********************* Callbacks are pretty handy when it comes to interacting with the `Trainer`. More precisely: `Trainer` defines a number of events as 'triggers' for callbacks. Currently, these are: .. code:: python BEGIN_OF_FIT, END_OF_FIT, BEGIN_OF_TRAINING_RUN, END_OF_TRAINING_RUN, BEGIN_OF_EPOCH, END_OF_EPOCH, BEGIN_OF_TRAINING_ITERATION, END_OF_TRAINING_ITERATION, BEGIN_OF_VALIDATION_RUN, END_OF_VALIDATION_RUN, BEGIN_OF_VALIDATION_ITERATION, END_OF_VALIDATION_ITERATION, BEGIN_OF_SAVE, END_OF_SAVE As an example, let's build a simple callback to interrupt the training on NaNs. We check at the end of every training iteration whether the training loss is NaN, and accordingly raise a `RuntimeError`. .. code:: python import numpy as np from inferno.trainers.callbacks.base import Callback class NaNDetector(Callback): def end_of_training_iteration(self, **_): # The callback object has the trainer as an attribute. # The trainer populates its 'states' with torch tensors (NOT VARIABLES!) training_loss = self.trainer.get_state('training_loss') # Extract float from torch tensor training_loss = training_loss[0] if np.isnan(training_loss): raise RuntimeError("NaNs detected!") With the callback defined, all we need to do is register it with the trainer: .. code:: python trainer.register_callback(NaNDetector()) So the next time you get `RuntimeError: "NaNs detected!`, you know the drill. Using Tensorboard ************************** Inferno supports logging scalars and images to Tensorboard out-of-the-box, though this requires you have at least [tensorflow-cpu](https://github.com/tensorflow/tensorflow) installed. 
Let's say you want to log scalars every iteration and images every 20 iterations: .. code:: python from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), log_images_every=(20, 'iterations')), log_directory='/path/to/log/directory') After you've started training, use a bash shell to fire up tensorboard with: .. code:: bash $ tensorboard --logdir=/path/to/log/directory --port=6007 and navigate to `localhost:6007` with your favorite browser. Fine print: missing the `log_images_every` keyword argument to `TensorboardLogger` will result in images being logged every iteration. If you don't have a fast hard drive, this might actually slow down the training. To not log images, just use `log_images_every='never'`. Using GPUs ************* To use just one GPU: .. code:: python trainer.cuda() For multi-GPU data-parallel training, simply pass `trainer.cuda` a list of devices: .. code:: python trainer.cuda(devices=[0, 1, 2, 3]) __Pro-tip__: Say you only want to use GPUs 0, 3, 5 and 7 (your colleagues might love you for this). Before running your training script, simply: .. code:: bash $ export CUDA_VISIBLE_DEVICES=0,3,5,7 $ python train.py This maps device 0 to 0, 3 to 1, 5 to 2 and 7 to 3. One more thing ************************** Once you have everything configured, use .. code:: python trainer.fit() to commence training! This last step is kinda important. :wink: Cherries: ~~~~~~~~~~~~~~~~~~~~~~ Building Complex Models with the Graph API **************************************************** Work in Progress: Parameter Initialization ************************** Work in Progress: Support ************* Work in Progress: ================================================ FILE: docs/zbibliography.rst ================================================ .. _inferno_bibliography: Bibliography ============================ The bibliography: .. 
bibliography:: refs.bib
   :style: alpha



================================================
FILE: examples/README.txt
================================================
.. _examples-index:

Gallery of Examples
===================



================================================
FILE: examples/plot_cheap_unet.py
================================================
"""
UNet Tutorial
================================

A UNet example that can be run without a GPU
"""

##############################################################################
# Preface
# --------------
# We start with some unspectacular multi-purpose imports needed for this example
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy

##############################################################################
# determine whether we have a GPU
# and should use CUDA
USE_CUDA = torch.cuda.is_available()

##############################################################################
# Dataset
# --------------
# For simplicity we will use a toy dataset where we need to perform
# a binary segmentation task.
from inferno.io.box.binary_blobs import get_binary_blob_loaders


def label_transform(x):
    """Cast a numpy label image to a float tensor, as required by BCE loss."""
    return torch.from_numpy(x).float()


train_loader, test_loader, validate_loader = get_binary_blob_loaders(
    size=8,                       # how many images per {train, test, validate}
    train_batch_size=2,
    length=256,                   # edge length of the images
    gaussian_noise_sigma=1.4,     # how noisy the images are
    train_label_transform=label_transform,
    validate_label_transform=label_transform
)

image_channels = 1   # number of channels of the image
pred_channels = 1    # number of channels needed for the prediction

if False:
    ##############################################################################
    # Visualize Dataset
    # ~~~~~~~~~~~~~~~~~~~~~~
    fig = plt.figure()
    for image, target in train_loader:
        ax = fig.add_subplot(1, 2, 1)
        ax.imshow(image[0, 0, ...])
        ax.set_title('raw data')

        ax = fig.add_subplot(1, 2, 2)
        ax.imshow(target[0, ...])
        ax.set_title('ground truth')
        break
    fig.tight_layout()
    plt.show()

##############################################################################
# Training
# ----------------------------
# To train the unet we use inferno's Trainer class.
# Since we train many models later on in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer from inferno.utils.python_utils import ensure_dir def train_model(model, loaders, **kwargs): trainer = Trainer(model) trainer.build_criterion('BCEWithLogitsLoss') trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001)) #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs')) #trainer.save_every((kwargs.get('save_every', 10), 'epochs')) #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor'))) trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20)) # bind the loaders trainer.bind_loader('train', loaders[0]) trainer.bind_loader('validate', loaders[1]) if USE_CUDA: trainer.cuda() # do the training trainer.fit() return trainer ############################################################################## # Prediction # ---------------------------- # The trainer contains the trained model and we can do predictions. # We use :code:`unwrap` to convert the results to numpy arrays. # Since we want to do many prediction we encapsulate the # the prediction in a function from inferno.utils.torch_utils import unwrap def predict(trainer, test_loader, save_dir=None): trainer.eval_mode() for image, target in test_loader: # transfer image to gpu image = image.cuda() if USE_CUDA else image # get batch size from image batch_size = image.size()[0] for b in range(batch_size): prediction = trainer.apply_model(image) prediction = torch.nn.functional.sigmoid(prediction) image = unwrap(image, as_numpy=True, to_cpu=True) prediction = unwrap(prediction, as_numpy=True, to_cpu=True) target = unwrap(target, as_numpy=True, to_cpu=True) fig = plt.figure() ax = fig.add_subplot(2, 2, 1) ax.imshow(image[b,0,...]) ax.set_title('raw data') ax = fig.add_subplot(2, 2, 2) ax.imshow(target[b,...]) ax.set_title('ground truth') ax = fig.add_subplot(2, 2, 4) ax.imshow(prediction[b,...]) ax.set_title('prediction') fig.tight_layout() plt.show() ############################################################################## # Custom 
UNet # ---------------------------- # Often one needs to have a UNet with custom layers. # Here we show how to implement such a customized UNet. # To this end we derive from :code:`UNetBase`. # For the sake of this example we will create # a Unet which uses depthwise convolutions and might be trained on a CPU from inferno.extensions.models import UNetBase from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D,ConvActivation class CheapConv(nn.Module): def __init__(self, in_channels, out_channels, activated): super(CheapConv, self).__init__() self.in_channels = in_channels self.out_channels = out_channels if activated: self.convs = torch.nn.Sequential( ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2), ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) ) else: self.convs = torch.nn.Sequential( ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2), Conv2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) ) def forward(self, x): assert x.shape[1] == self.in_channels,"input has wrong number of channels" x = self.convs(x) assert x.shape[1] == self.out_channels,"output has wrong number of channels" return x class CheapConvBlock(nn.Module): def __init__(self, in_channels, out_channels, activated): super(CheapConvBlock, self).__init__() self.activated = activated self.in_channels = in_channels self.out_channels = out_channels if(in_channels != out_channels): self.start = ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) else: self.start = None self.conv_a = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=True) self.conv_b = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=False) self.activation = torch.nn.ReLU() def forward(self, x): x_input = x if 
self.start is not None: x_input = self.start(x_input) x = self.conv_a(x_input) x = self.conv_b(x) x = x + x_input if self.activated: x = self.activation(x) return x class MySimple2DCpUnet(UNetBase): def __init__(self, in_channels, out_channels, depth=3, residual=False, **kwargs): super(MySimple2DCpUnet, self).__init__(in_channels=in_channels, out_channels=out_channels, dim=2, depth=depth, **kwargs) def conv_op_factory(self, in_channels, out_channels, part, index): # last? last = part == 'up' and index==0 return CheapConvBlock(in_channels=in_channels, out_channels=out_channels, activated=not last),False from inferno.extensions.layers import RemoveSingletonDimension model_b = torch.nn.Sequential( CheapConv(in_channels=image_channels, out_channels=4, activated=True), MySimple2DCpUnet(in_channels=4, out_channels=pred_channels) , RemoveSingletonDimension(dim=1) ) ################################################### # do the training (with the same functions as before) trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001) ################################################### # do the training (with the same functions as before)1 predict(trainer=trainer, test_loader=test_loader) ================================================ FILE: examples/plot_train_side_loss_unet.py ================================================ """ Train Side Loss UNet Example ================================ In this example a UNet with side supervision and auxiliary loss implemented """ ############################################################################## # Imports needed for this example import torch import torch.nn as nn from inferno.io.box.binary_blobs import get_binary_blob_loaders from inferno.trainers.basic import Trainer from inferno.extensions.layers.convolutional import Conv2D from inferno.extensions.models.res_unet import _ResBlock as ResBlock from inferno.extensions.models import ResBlockUNet from inferno.utils.torch_utils 
import unwrap from inferno.utils.python_utils import ensure_dir import pylab ############################################################################## # To create a UNet with side loss we create a new nn.Module class # which has a ResBlockUNet as member. # The ResBlockUNet is configured such that the results of the # bottom convolution and all the results of the up-stream # convolutions are returned as (side)-output. # a 1x1 convolutions is used to give the side outputs # the right number of out_channels and UpSampling is # used to resize all side-outputs to the full resolution # of the input. These side `side-predictions` are # returned by our MySideLossUNet. # Furthermore, all `side-predictions` are concatenated # and feed trough another two residual blocks to make # the final prediction. class MySideLossUNet(nn.Module): def __init__(self, in_channels, out_channels, depth=3): super(MySideLossUNet, self).__init__() self.depth = depth self.unet = ResBlockUNet(in_channels=in_channels, out_channels=in_channels*2, dim=2, unet_kwargs=dict(depth=depth), side_out_parts=['bottom', 'up']) # number of out channels self.n_channels_per_output = self.unet.n_channels_per_output # 1x1 conv to give the side outs of the unet # the right number of channels # and a Upsampling to give the right shape upscale_factor = 2**self.depth conv_and_scale = [] for n_channels in self.n_channels_per_output: # conv blocks conv = Conv2D(in_channels=n_channels, out_channels=out_channels, kernel_size=1) if upscale_factor > 1: upsample = nn.Upsample(scale_factor=upscale_factor) conv_and_scale.append(nn.Sequential(conv, upsample)) else: conv_and_scale.append(conv) upscale_factor //= 2 self.conv_and_scale = nn.ModuleList(conv_and_scale) # combined number of channels after concat # concat side output predictions with main output of unet self.n_channels_combined = (self.depth + 1)* out_channels + in_channels*2 self.final_block = nn.Sequential( ResBlock(dim=2,in_channels=self.n_channels_combined, 
out_channels=self.n_channels_combined), ResBlock(in_channels=self.n_channels_combined, out_channels=out_channels, dim=2, activated=False), ) def forward(self, input): outs = self.unet(input) assert len(outs) == len(self.n_channels_per_output) # convert the unet output into the right number of preds = [None] * len(outs) for i,out in enumerate(outs): preds[i] = self.conv_and_scale[i](out) # this is the side output preds = tuple(preds) # concat side output predictions with main output of unet combined = torch.cat(preds + (outs[-1],), 1) final_res = self.final_block(combined) # return everything return preds + (final_res,) ############################################################################## # We use a custom loss functions which applied CrossEntropyLoss # to all side outputs. # The side outputs are weighted in a quadratic fashion and added up # into a single value class MySideLoss(nn.Module): """Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion. 
""" def __init__(self): super(MySideLoss, self).__init__() self.criterion = nn.CrossEntropyLoss(reduce=True) w = 1.0 l = None def forward(self, predictions, target): w = 1.0 l = None for p in predictions: ll = self.criterion(p, target)*w if l is None: l = ll else: l += ll w *= 2 return l ############################################################################## # Training boilerplate (see :ref:`sphx_glr_auto_examples_trainer.py`) LOG_DIRECTORY = ensure_dir('log') SAVE_DIRECTORY = ensure_dir('save') DATASET_DIRECTORY = ensure_dir('dataset') USE_CUDA = torch.cuda.is_available() # Build a residual unet where the last layer is not activated sl_unet = MySideLossUNet(in_channels=5, out_channels=2) model = nn.Sequential( ResBlock(dim=2, in_channels=1, out_channels=5), sl_unet ) train_loader, test_loader, validate_loader = get_binary_blob_loaders( train_batch_size=3, length=512, # <= size of the images gaussian_noise_sigma=1.5 # <= how noise are the images ) # Build trainer trainer = Trainer(model) trainer.build_criterion(MySideLoss()) trainer.build_optimizer('Adam') trainer.validate_every((10, 'epochs')) #trainer.save_every((10, 'epochs')) #trainer.save_to_directory(SAVE_DIRECTORY) trainer.set_max_num_epochs(40) # Bind loaders trainer \ .bind_loader('train', train_loader)\ .bind_loader('validate', validate_loader) if USE_CUDA: trainer.cuda() # Go! 
trainer.fit() ############################################################################## # Predict with the trained network # and visualize the results # predict: #trainer.load(best=True) trainer.bind_loader('train', train_loader) trainer.bind_loader('validate', validate_loader) trainer.eval_mode() if USE_CUDA: trainer.cuda() # look at an example for img,target in test_loader: if USE_CUDA: img = img.cuda() # softmax on each of the prediction preds = trainer.apply_model(img) preds = [nn.functional.softmax(pred,dim=1) for pred in preds] preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds] img = unwrap(img, as_numpy=True, to_cpu=True) target = unwrap(target, as_numpy=True, to_cpu=True) n_plots = len(preds) + 2 batch_size = preds[0].shape[0] for b in range(batch_size): fig = pylab.figure() ax1 = fig.add_subplot(2,4,1) ax1.set_title('image') ax1.imshow(img[b,0,...]) ax2 = fig.add_subplot(2,4,2) ax2.set_title('ground truth') ax2.imshow(target[b,...]) for i,pred in enumerate(preds): axn = fig.add_subplot(2,4, 3+i) axn.imshow(pred[b,1,...]) if i + 1 < len(preds): axn.set_title('side prediction %d'%i) else: axn.set_title('combined prediction') pylab.show() break ================================================ FILE: examples/plot_unet_tutorial.py ================================================ """ UNet Tutorial ================================ A tentative tutorial on the usage of the unet framework in inferno """ ############################################################################## # Preface # -------------- # We start with some unspectacular multi purpose imports needed for this example import matplotlib.pyplot as plt import torch import numpy ############################################################################## # determine whether we have a gpu # and should use cuda USE_CUDA = torch.cuda.is_available() ############################################################################## # Dataset # -------------- # For simplicity we will 
use a toy dataset where we need to perform # a binary segmentation task. from inferno.io.box.binary_blobs import get_binary_blob_loaders # convert labels from long to float as needed by # binary cross entropy loss def label_transform(x): return torch.from_numpy(x).float() #label_transform = lambda x : torch.from_numpy(x).float() train_loader, test_loader, validate_loader = get_binary_blob_loaders( size=8, # how many images per {train,test,validate} train_batch_size=2, length=256, # <= size of the images gaussian_noise_sigma=1.4, # <= how noise are the images train_label_transform = label_transform, validate_label_transform = label_transform ) image_channels = 1 # <-- number of channels of the image pred_channels = 1 # <-- number of channels needed for the prediction ############################################################################## # Visualize Dataset # ~~~~~~~~~~~~~~~~~~~~~~ fig = plt.figure() for i,(image, target) in enumerate(train_loader): ax = fig.add_subplot(1, 2, 1) ax.imshow(image[0,0,...]) ax.set_title('raw data') ax = fig.add_subplot(1, 2, 2) ax.imshow(target[0,...]) ax.set_title('ground truth') break fig.tight_layout() plt.show() ############################################################################## # Simple UNet # ---------------------------- # We start with a very simple predefined # res block UNet. By default, this UNet uses ReLUs (in conjunction with batchnorm) as nonlinearities # With :code:`activated=False` we make sure that the last layer # is not activated since we chain the UNet with a sigmoid # activation function. 
from inferno.extensions.models import ResBlockUNet
from inferno.extensions.layers import RemoveSingletonDimension

# A predefined residual-block UNet chained with a sigmoid.  The UNet's last
# layer is left un-activated (``activated=False``) because the sigmoid follows.
model = torch.nn.Sequential(
    ResBlockUNet(dim=2,
                 in_channels=image_channels,
                 out_channels=pred_channels,
                 activated=False),
    RemoveSingletonDimension(dim=1),
    torch.nn.Sigmoid()
)

##############################################################################
# while the model above will work in principle, it has some drawbacks.
# Within the UNet, the number of features is increased by a multiplicative
# factor while going down, the so-called gain. The default value for the gain is 2.
# Since we start with only a single channel we could either increase the gain,
# or use some convolutions to increase the number of channels
# before the UNet.
from inferno.extensions.layers import ConvReLU2D

model_a = torch.nn.Sequential(
    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    ResBlockUNet(dim=2,
                 in_channels=5,
                 out_channels=pred_channels,
                 activated=False,
                 res_block_kwargs=dict(batchnorm=True, size=2)),
    RemoveSingletonDimension(dim=1)
    # torch.nn.Sigmoid()
)

##############################################################################
# Training
# ----------------------------
# To train the unet we use inferno's Trainer class.
# Since we train many models later on in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer
from inferno.utils.python_utils import ensure_dir

def train_model(model, loaders, **kwargs):
    """Train `model` with inferno's Trainer and return the fitted trainer.

    `loaders` is a [train_loader, validate_loader] pair.  Recognized kwargs:
    `lr` (default 0.0001) and `max_num_epochs` (default 200).

    NOTE(review): the `save_dir`, `validate_every` and `save_every` kwargs
    are currently ignored -- the corresponding trainer calls are commented
    out below, so the `save_dir` passed by the callers has no effect.
    """
    trainer = Trainer(model)
    trainer.build_criterion('BCEWithLogitsLoss')
    trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))
    #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
    #trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))
    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 200))

    # bind the loaders
    trainer.bind_loader('train', loaders[0])
    trainer.bind_loader('validate', loaders[1])

    if USE_CUDA:
        trainer.cuda()

    # do the training
    trainer.fit()

    return trainer


trainer = train_model(model=model_a, loaders=[train_loader, validate_loader],
                      save_dir='model_a', lr=0.01)

##############################################################################
# Prediction
# ----------------------------
# The trainer contains the trained model and we can do predictions.
# We use :code:`unwrap` to convert the results to numpy arrays.
# Since we want to do many predictions we encapsulate the
# prediction in a function
from inferno.utils.torch_utils import unwrap

def predict(trainer, test_loader,  save_dir=None):
    """Run the trained model over `test_loader` and plot raw/target/prediction.

    NOTE(review): `save_dir` is accepted for interface compatibility but is
    currently unused.
    """
    trainer.eval_mode()
    for image, target in test_loader:

        # transfer image to gpu
        image = image.cuda() if USE_CUDA else image

        # FIX: run the forward pass and the unwrap conversions ONCE per batch.
        # Previously this all happened inside the per-sample loop below, so on
        # the second sample `image` had already been converted to a numpy
        # array and `trainer.apply_model(image)` failed; the prediction was
        # also needlessly recomputed for every sample.
        prediction = trainer.apply_model(image)
        # torch.nn.functional.sigmoid is deprecated; use torch.sigmoid
        prediction = torch.sigmoid(prediction)

        # get batch size from image (while it is still a torch tensor)
        batch_size = image.size()[0]

        image = unwrap(image, as_numpy=True, to_cpu=True)
        prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
        target = unwrap(target, as_numpy=True, to_cpu=True)

        # plot each sample of the batch individually
        for b in range(batch_size):
            fig = plt.figure()

            ax = fig.add_subplot(2, 2, 1)
            ax.imshow(image[b,0,...])
            ax.set_title('raw data')

            ax = fig.add_subplot(2, 2, 2)
            ax.imshow(target[b,...])
            ax.set_title('ground truth')

            ax = fig.add_subplot(2, 2, 4)
            ax.imshow(prediction[b,...])
            ax.set_title('prediction')

            fig.tight_layout()
            plt.show()

###################################################
# do the prediction
predict(trainer=trainer, test_loader=test_loader)

##############################################################################
# Custom UNet
# ----------------------------
# Often one needs to have a UNet with custom layers.
# Here we show how to implement such a customized UNet.
# To this end we derive from :code:`UNetBase`.
# For the sake of this example we will create
# a rather exotic UNet which uses different types
# of convolutions/non-linearities in the different branches
# of the unet
from inferno.extensions.models import UNetBase
from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D
from inferno.extensions.layers.sampling import Upsample


class MySimple2DUnet(UNetBase):
    """Custom 2D UNet: ELU conv pairs on the way down, ReLU conv pairs at
    the bottom, and ELU followed by ReLU (or a plain conv in the very last
    block) on the way up."""
    def __init__(self, in_channels, out_channels, depth=3, **kwargs):
        super(MySimple2DUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                             dim=2, depth=depth, **kwargs)

    def conv_op_factory(self, in_channels, out_channels, part, index):
        # `part` is one of 'down' / 'bottom' / 'up', `index` the level within
        # that branch.  The second tuple element (False) is a flag consumed by
        # UNetBase -- presumably it disables an extra residual/skip handling;
        # TODO confirm against UNetBase.
        if part == 'down':
            return torch.nn.Sequential(
                ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
                ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
            ), False
        elif part == 'bottom':
            return torch.nn.Sequential(
                ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
                ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
            ), False
        elif part == 'up':
            # are we in the very last block?
            if index == 0:
                return torch.nn.Sequential(
                    ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
                    Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
                ), False
            else:
                return torch.nn.Sequential(
                    ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
                    ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
                ), False
        else:
            raise RuntimeError("something is wrong")

    # this function CAN be implemented, if not, MaxPooling is used by default
    def downsample_op_factory(self, index):
        return torch.nn.MaxPool2d(kernel_size=2, stride=2)

    # this function CAN be implemented, if not, Upsampling is used by default
    def upsample_op_factory(self, index):
        return Upsample(mode='bilinear', align_corners=False,scale_factor=2)


model_b = torch.nn.Sequential(
    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    MySimple2DUnet(in_channels=5, out_channels=pred_channels) ,
    RemoveSingletonDimension(dim=1)
)


###################################################
# do the training (with the same functions as before)
trainer = train_model(model=model_b, loaders=[train_loader, validate_loader],
                      save_dir='model_b', lr=0.001)

###################################################
# do the prediction (with the same functions as before)
predict(trainer=trainer, test_loader=test_loader)


================================================
FILE: examples/regularized_mnist.py
================================================
"""
Regularized MNIST Example
================================

This example demonstrates adding and logging arbitrary regularization losses,
in this case, L2 activity regularization and L1 weight regularization.
- Add a `_losses` dictionary to any module containing loss names and values
- Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses
- Call `Trainer.observe_training_and_validation_states` to log the losses as well
"""
import argparse
import sys

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from inferno.extensions.layers.reshape import Flatten
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger


class RegularizedLinear(nn.Linear):
    """Linear layer that records an L2 activity penalty and an L1 weight
    penalty in `self._losses` on every forward pass, to be picked up by
    `inferno.extensions.criteria.regularized.collect_losses`."""

    def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):
        super(RegularizedLinear, self).__init__(*args, **kwargs)
        self.ar_weight = ar_weight  # weight of the L2 activity penalty
        self.l1_weight = l1_weight  # weight of the L1 weight penalty
        self._losses = {}

    def forward(self, input):
        output = super(RegularizedLinear, self).forward(input)
        # recompute both penalties for the current batch
        self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight
        self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight
        return output


def model_fn():
    # simple 784 -> 256 -> 128 -> 10 MLP with regularized linear layers
    return nn.Sequential(
        Flatten(),
        RegularizedLinear(in_features=784, out_features=256),
        nn.LeakyReLU(),
        RegularizedLinear(in_features=256, out_features=128),
        nn.LeakyReLU(),
        RegularizedLinear(in_features=128, out_features=10)
    )


def mnist_data_loaders(args):
    """Return (train, test) MNIST data loaders configured from `args`."""
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
    return train_loader, test_loader


def train_model(args):
    """Build the model and trainer from `args`, then run the training."""
    model = model_fn()
    train_loader, validate_loader = mnist_data_loaders(args)

    # Build trainer
    trainer = Trainer(model) \
        .build_criterion('RegularizedCrossEntropyLoss') \
        .build_metric('CategoricalError') \
        .build_optimizer('Adam') \
        .validate_every((1, 'epochs')) \
        .save_every((1, 'epochs')) \
        .save_to_directory(args.save_directory) \
        .set_max_num_epochs(args.epochs) \
        .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                        log_images_every='never'),
                      log_directory=args.save_directory)

    # Record regularization losses
    trainer.logger.observe_training_and_validation_states([
        'main_loss', 'total_regularization_loss', 'activity_regularization',
        'l1_weight_regularization'
    ])

    # Bind loaders
    trainer \
        .bind_loader('train', train_loader) \
        .bind_loader('validate', validate_loader)

    if args.cuda:
        trainer.cuda()

    # Go!
    trainer.fit()


def main(argv):
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--save-directory', type=str, default='output/mnist/v1',
                        help='output directory')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=20, metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    args = parser.parse_args(argv)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    train_model(args)


if __name__ == '__main__':
    main(sys.argv[1:])


================================================
FILE: examples/trainer.py
================================================
"""
Trainer Example
================================

This example should illustrate how to use the trainer class.
""" import torch.nn as nn from inferno.io.box.cifar import get_cifar10_loaders from inferno.trainers.basic import Trainer from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger from inferno.extensions.layers import ConvELU2D from inferno.extensions.layers import Flatten from inferno.utils.python_utils import ensure_dir from inferno.extensions.layers import SELU ################################################## # change directories to your needs LOG_DIRECTORY = ensure_dir('log') SAVE_DIRECTORY = ensure_dir('save') DATASET_DIRECTORY = ensure_dir('dataset') ################################################## # shall models be downloaded DOWNLOAD_CIFAR = True USE_CUDA = True ################################################## # Build torch model model = nn.Sequential( ConvELU2D(in_channels=3, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), ConvELU2D(in_channels=256, out_channels=256, kernel_size=3), nn.MaxPool2d(kernel_size=2, stride=2), Flatten(), nn.Linear(in_features=(256 * 4 * 4), out_features=10), nn.Softmax() ) ################################################## # data loaders train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY, download=DOWNLOAD_CIFAR) ################################################## # Build trainer trainer = Trainer(model) trainer.build_criterion('CrossEntropyLoss') trainer.build_metric('CategoricalError') trainer.build_optimizer('Adam') trainer.validate_every((2, 'epochs')) trainer.save_every((5, 'epochs')) trainer.save_to_directory(SAVE_DIRECTORY) trainer.set_max_num_epochs(10) trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), log_images_every='never'), log_directory=LOG_DIRECTORY) ################################################## # Bind loaders trainer.bind_loader('train', train_loader) trainer.bind_loader('validate', validate_loader) 
################################################## # activate cuda if USE_CUDA: trainer.cuda() ################################################## # fit trainer.fit() ================================================ FILE: inferno/__init__.py ================================================ # -*- coding: utf-8 -*- """Top-level package for inferno.""" from . import extensions from . import io from . import trainers from . import utils from .version import __version__ __all__ = ['extensions', 'io', 'trainers', 'utils'] __author__ = """Nasim Rahaman""" __email__ = 'nasim.rahaman@iwr.uni-heidelberg.de' ================================================ FILE: inferno/extensions/__init__.py ================================================ from . import containers from . import criteria from . import initializers from . import layers from . import metrics from . import optimizers from . import models # Backward support from . import models as model __all__ = ['containers', 'criteria', 'initializers', 'layers', 'metrics', 'optimizers', 'models', 'model'] ================================================ FILE: inferno/extensions/containers/__init__.py ================================================ from .graph import * from .sequential import * ================================================ FILE: inferno/extensions/containers/graph.py ================================================ from collections import OrderedDict import sys import threading import multiprocessing as mp import copy import gc import networkx as nx from networkx import is_directed_acyclic_graph, topological_sort from torch import nn as nn from ...utils import python_utils as pyu from ...utils.exceptions import assert_ from ..layers.device import OnDevice from ..layers.identity import Identity __all__ = ['NNGraph', 'Graph'] class NNGraph(nx.DiGraph): """A NetworkX DiGraph, except that node and edge ordering matters.""" # We don't copy torch tensors, only to have them deleted. 
ATTRIBUTES_TO_NOT_COPY = {'payload'} node_dict_factory = OrderedDict adjlist_dict_factory = OrderedDict def copy(self, **init_kwargs): new = type(self)(**init_kwargs) # Remove all attributes and copy only the graph structure for source, target in self.edges_iter(): # Add new nodes new.add_node(source) new.add_node(target) # Copy attributes new.node[source].update(copy.deepcopy({key: value for key, value in self.node[source].items() if key not in self.ATTRIBUTES_TO_NOT_COPY})) new.node[target].update(copy.deepcopy({key: value for key, value in self.node[target].items() if key not in self.ATTRIBUTES_TO_NOT_COPY})) # Add new edge new.add_edge(copy.deepcopy(source), copy.deepcopy(target)) old_edge_attributes = self[source][target] new_edge_attributes = {key: value for key, value in old_edge_attributes.items() if key not in self.ATTRIBUTES_TO_NOT_COPY} new_edge_attributes = copy.deepcopy(new_edge_attributes) new[source][target].update(new_edge_attributes) return new class Graph(nn.Module): """ A graph structure to build networks with complex architectures. The resulting graph model can be used like any other `torch.nn.Module`. The graph structure used behind the scenes is a `networkx.DiGraph`. This internal graph is exposed by the `apply_on_graph` method, which can be used with any NetworkX function (e.g. for plotting with matplotlib or GraphViz). 
Examples -------- The naive inception module (without the max-pooling for simplicity) with ELU-layers of 64 units can be built as following, (assuming 64 input channels): >>> from inferno.extensions.layers.reshape import Concatenate >>> from inferno.extensions.layers.convolutional import ConvELU2D >>> import torch >>> # Build the model >>> inception_module = Graph() >>> inception_module.add_input_node('input') >>> inception_module.add_node('conv1x1', ConvELU2D(64, 64, 3), previous='input') >>> inception_module.add_node('conv3x3', ConvELU2D(64, 64, 3), previous='input') >>> inception_module.add_node('conv5x5', ConvELU2D(64, 64, 3), previous='input') >>> inception_module.add_node('cat', Concatenate(), >>> previous=['conv1x1', 'conv3x3', 'conv5x5']) >>> inception_module.add_output_node('output', 'cat') >>> # Build dummy variable >>> input = torch.rand(1, 64, 100, 100) >>> # Get output >>> output = inception_module(input) """ def __init__(self, graph=None): """ Construct the graph object. Parameters ---------- graph : networkx.DiGraph or NNGraph Graph to build the object from (optional). """ super(Graph, self).__init__() # Privates self._thread_to_graph_mapping = {} self._creator_thread = threading.get_ident() self._creator_pid = mp.current_process().pid # Publics if graph is not None: self.graph = graph else: self.graph = NNGraph() @property def graph(self): # `graph` needs to be different for every thread, because torch.nn.parallel.replicate does # not make a copy. 
graph = self._thread_to_graph_mapping.get(threading.get_ident()) if graph is None: creator_thread_graph = self._thread_to_graph_mapping.get(self._creator_thread) assert creator_thread_graph is not None graph = creator_thread_graph.copy() # We don't need to clear payloads because the copy method of NNGraph copies only the # graph structure and not the attributes self._thread_to_graph_mapping.update({threading.get_ident(): graph}) return graph @graph.setter def graph(self, value): assert_(isinstance(value, NNGraph), exception_type=TypeError) self._thread_to_graph_mapping.update({threading.get_ident(): value}) def is_node_in_graph(self, name): """ Checks whether a node is in the graph. Parameters ---------- name : str Name of the node. Returns ------- bool """ return name in self.graph.nodes def is_source_node(self, name): """ Checks whether a given node (by name) is a source node. A source node has no incoming edges. Parameters ---------- name : str Name of the node. Returns ------- bool Raises ------ AssertionError if node is not found in the graph. """ assert self.is_node_in_graph(name) return self.graph.in_degree(name) == 0 def is_sink_node(self, name): """ Checks whether a given node (by name) is a sink node. A sink node has no outgoing edges. Parameters ---------- name : str Name of the node. Returns ------- bool Raises ------ AssertionError if node is not found in the graph. """ assert self.is_node_in_graph(name) return self.graph.out_degree(name) == 0 @property def output_nodes(self): """ Gets a list of output nodes. The order is relevant and is the same as that in which the forward method returns its outputs. Returns ------- list A list of names (str) of the output nodes. """ return [name for name, node_attributes in self.graph.nodes.items() if node_attributes.get('is_output_node', False)] @property def input_nodes(self): """ Gets a list of input nodes. The order is relevant and is the same as that in which the forward method accepts its inputs. 
Returns ------- list A list of names (str) of the input nodes. """ return [name for name, node_attributes in self.graph.nodes.items() if node_attributes.get('is_input_node', False)] @property def graph_is_valid(self): """Checks if the graph is valid.""" # Check if the graph is a DAG is_dag = is_directed_acyclic_graph(self.graph) # Check if output nodes are sinks output_nodes_are_sinks = all([self.is_sink_node(name) for name in self.output_nodes]) # Check inf input nodes are sources input_nodes_are_sources = all([self.is_source_node(name) for name in self.input_nodes]) # TODO Check whether only input nodes are sources and only output nodes are sinks # Conclude is_valid = is_dag and output_nodes_are_sinks and input_nodes_are_sources return is_valid def assert_graph_is_valid(self): """Asserts that the graph is valid.""" assert is_directed_acyclic_graph(self.graph), "Graph is not a DAG." for name in self.output_nodes: assert self.is_sink_node(name), "Output node {} is not a sink.".format(name) assert not self.is_source_node(name), "Output node {} is a source node. " \ "Make sure it's connected.".format(name) for name in self.input_nodes: assert self.is_source_node(name), "Input node {} is not a source.".format(name) assert not self.is_sink_node(name), "Input node {} is a sink node. " \ "Make sure it's connected.".format(name) def add_node(self, name, module, previous=None): """ Add a node to the graph. Parameters ---------- name : str Name of the node. Nodes are identified by their names. module : torch.nn.Module Torch module for this node. previous : str or list of str (List of) name(s) of the previous node(s). Returns ------- Graph self """ assert isinstance(module, nn.Module) self.add_module(name, module) self.graph.add_node(name) if previous is not None: for _previous in pyu.to_iterable(previous): self.add_edge(_previous, name) return self def add_input_node(self, name): """ Add an input to the graph. 
The order in which input nodes are added is the order in which the forward method accepts its inputs. Parameters ---------- name : str Name of the input node. Returns ------- Graph self """ self.add_module(name, Identity()) self.graph.add_node(name, is_input_node=True) return self def add_output_node(self, name, previous=None): """ Add an output to the graph. The order in which output nodes are added is the order in which the forward method returns its outputs. Parameters ---------- name : str Name of the output node. Returns ------- Graph self """ self.graph.add_node(name, is_output_node=True) if previous is not None: for _previous in pyu.to_iterable(previous): self.add_edge(_previous, name) return self def add_edge(self, from_node, to_node): """ Add an edge between two nodes. Parameters ---------- from_node : str Name of the source node. to_node : str Name of the target node. Returns ------- Graph self Raises ------ AssertionError if either of the two nodes is not in the graph, or if the edge is not 'legal'. """ assert self.is_node_in_graph(from_node) assert self.is_node_in_graph(to_node) self.graph.add_edge(from_node, to_node) assert self.graph_is_valid return self def apply_on_graph(self, function, *args, **kwargs): """Applies a `function` on the internal graph.""" return function(self, *args, **kwargs) def get_module_for_nodes(self, names): """ Gets the `torch.nn.Module` object for nodes corresponding to `names`. Parameters ---------- names : str or list of str Names of the nodes to fetch the modules of. Returns ------- list or torch.nn.Module Module or a list of modules corresponding to `names`. 
""" names = pyu.to_iterable(names) modules = [] for name in names: assert self.is_node_in_graph(name), "Node '{}' is not in graph.".format(name) module = getattr(self, name, None) assert module is not None, "Node '{}' is in the graph but could not find a module " \ "corresponding to it.".format(name) modules.append(module) return pyu.from_iterable(modules) def to_device(self, names, target_device, device_ordinal=None, asynchronous=False): """Transfer nodes in the network to a specified device.""" names = pyu.to_iterable(names) for name in names: assert self.is_node_in_graph(name), "Node '{}' is not in graph.".format(name) module = getattr(self, name, None) assert module is not None, "Node '{}' is in the graph but could not find a module " \ "corresponding to it.".format(name) # Transfer module_on_device = OnDevice(module, target_device, device_ordinal=device_ordinal, asynchronous=asynchronous) setattr(self, name, module_on_device) return self def get_parameters_for_nodes(self, names, named=False): """Get parameters of all nodes listed in `names`.""" if not named: parameters = (parameter for module in pyu.to_iterable(self.get_module_for_nodes(names)) for parameter in module.parameters()) else: parameters = ((name, parameter) for module in pyu.to_iterable(self.get_module_for_nodes(names)) for name, parameter in module.named_parameters()) return parameters def clear_payloads(self, graph=None): graph = self.graph if graph is None else graph for edge in list(graph.edges(data=True)): source, target, _ = edge if 'payload' in graph[source][target]: del graph[source][target]['payload'] def forward_through_node(self, name, input=None): # If input is a tuple/list, it will NOT be unpacked. 
# Make sure the node is in the graph if input is None: # Make sure the node is not a source node assert not self.is_source_node(name), \ "Node '{}' did not get an input but is a source node.".format(name) # Get input from payload incoming_edges = self.graph.in_edges(name) input = [] for incoming, this in incoming_edges: # Append to input input.append(self.graph[incoming][this]['payload']) # Clear reference for the garbage collector to do its thing del self.graph[incoming][this]['payload'] else: assert self.is_node_in_graph(name) # Convert input to list input = [input] # Get outputs try: outputs = pyu.to_iterable(getattr(self, name)(*input)) except Exception as e: input_spec_string = "\n".join(["--[{}]-{}-->[{}]".format(incoming, tuple(_input.size()), this) for (incoming, this), _input in zip(self.graph.in_edges(name), input)]) message = "In node '{}': {}\n" \ "Inputs to this node were:\n{}"\ .format(name, str(e), input_spec_string) raise type(e)(message).with_traceback(sys.exc_info()[2]) # Distribute outputs to outgoing payloads if required if not self.is_sink_node(name): outgoing_edges = self.graph.out_edges(name) if len(outputs) == 1: # Support for replication outputs *= len(outgoing_edges) # Make sure the number of outputs check out assert len(outputs) == len(outgoing_edges), \ "Number of outputs from the model ({}) does not match the number " \ "of out-edges ({}) in the graph for this node ('{}').".format(len(outputs), len(outgoing_edges), name) for (this, outgoing), output in zip(outgoing_edges, outputs): self.graph[this][outgoing].update({'payload': output}) # Collect garbage to free some GPU memory? 
del input gc.collect() # Return outputs return pyu.from_iterable(outputs) def forward(self, *inputs): self.assert_graph_is_valid() input_nodes = self.input_nodes output_nodes = self.output_nodes assert len(inputs) == len(input_nodes), "Was expecting {} " \ "arguments for as many input nodes, got {}."\ .format(len(input_nodes), len(inputs)) # Unpack inputs to input nodes for input, input_node in zip(inputs, input_nodes): self.forward_through_node(input_node, input=input) # Toposort the graph toposorted = topological_sort(self.graph) # Remove all input and output nodes toposorted = [name for name in toposorted if name not in input_nodes and name not in output_nodes] # Since we'll be clearing payloads anyway, it makes no sense whatsoever # to evaluate sink nodes toposorted = [name for name in toposorted if not self.is_sink_node(name)] # Forward for node in toposorted: self.forward_through_node(node) # Read outputs from output nodes outputs = [] for output_node in output_nodes: # Get all incoming edges to output node outputs_from_node = [self.graph[incoming][this]['payload'] for incoming, this in self.graph.in_edges(output_node)] outputs.append(pyu.from_iterable(outputs_from_node)) # Clear payloads for next pass self.clear_payloads() # Done. return pyu.from_iterable(outputs) ================================================ FILE: inferno/extensions/containers/sequential.py ================================================ import torch.nn as nn from ...utils import python_utils as pyu __all__ = ['Sequential1', 'Sequential2'] class Sequential1(nn.Sequential): """Like torch.nn.Sequential, but with a few extra methods.""" def __len__(self): return len(self._modules.values()) class Sequential2(Sequential1): """Another sequential container. Identitcal to torch.nn.Sequential, except that modules may return multiple outputs and accept multiple inputs. 
""" def forward(self, *input): for module in self._modules.values(): input = pyu.to_iterable(module(*pyu.to_iterable(input))) return pyu.from_iterable(input) ================================================ FILE: inferno/extensions/criteria/__init__.py ================================================ from .set_similarity_measures import * from .elementwise_measures import * from .core import * from .regularized import * __all__ = ['set_similarity_measures', 'elementwise_measures','core','regularized'] ================================================ FILE: inferno/extensions/criteria/core.py ================================================ import torch.nn as nn from functools import reduce from ...utils.exceptions import assert_, ShapeError, NotTorchModuleError __all__ = ['Criteria', 'As2DCriterion'] class Criteria(nn.Module): """Aggregate multiple criteria to one.""" def __init__(self, *criteria): super(Criteria, self).__init__() if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)): criteria = list(criteria[0]) else: criteria = list(criteria) # Validate criteria assert all([isinstance(criterion, nn.Module) for criterion in criteria]), \ "Criterion must be a torch module." self.criteria = criteria def forward(self, prediction, target): assert isinstance(prediction, (list, tuple)), \ "`prediction` must be a list or a tuple, got {} instead."\ .format(type(prediction).__name__) assert isinstance(target, (list, tuple)), \ "`prediction` must be a list or a tuple, got {} instead." \ .format(type(target).__name__) assert len(prediction) == len(target), \ "Number of predictions must equal the number of targets. 
" \ "Got {} predictions but {} targets.".format(len(prediction), len(target)) # Compute losses losses = [criterion(prediction, target) for _prediction, _target, criterion in zip(prediction, target, self.criteria)] # Aggegate losses loss = reduce(lambda x, y: x + y, losses) # Done return loss class As2DCriterion(nn.Module): """ Makes a given criterion applicable on (N, C, H, W) prediction and (N, H, W) target tensors, if they're applicable to (N, C) prediction and (N,) target tensors . """ def __init__(self, criterion): super(As2DCriterion, self).__init__() assert_(isinstance(criterion, nn.Module), "Criterion must be a module, got a {} instead." .format(type(criterion).__name__), NotTorchModuleError) self.criterion = criterion def forward(self, prediction, target): # Validate input assert_(prediction.dim() == 4, "`prediction` is expected to be a 4D tensor of shape " "(N, C, H, W), got a {}D " "tensor instead.".format(prediction.dim()), ShapeError) assert_(target.dim() == 3, "`target` is expected to be a 3D tensor of shape " "(N, H, W), got a {}D " "tensor instead.".format(target.dim()), ShapeError) # prediction is assumed to be NCHW, and target NHW. # this makes target (NHW,) target = target.contiguous().view(-1) # This makes prediction (N, H, W, C) --> (NHW, C) num_channels = prediction.size(1) prediction = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_channels) # Now, the criterion should be applicable as is loss = self.criterion(prediction, target) return loss ================================================ FILE: inferno/extensions/criteria/elementwise_measures.py ================================================ import torch.nn as nn from ...utils.exceptions import assert_ class WeightedMSELoss(nn.Module): NEGATIVE_CLASS_WEIGHT = 1. def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True): super(WeightedMSELoss, self).__init__() assert_(positive_class_weight >= 0, "Positive class weight can't be less than zero, got {}." 
.format(positive_class_weight), ValueError) self.mse = nn.MSELoss(size_average=size_average) self.positive_class_weight = positive_class_weight self.positive_class_value = positive_class_value def forward(self, input, target): # Get a mask positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data) # Get differential weights (positive_weight - negative_weight, # i.e. subtract 1, assuming the negative weight is gauged at 1) weight_differential = (positive_class_mask .mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT)) # Get final weight by adding weight differential to a tensor with negative weights weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT) # `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with. sqrt_weights = weights.sqrt_() return self.mse(input * sqrt_weights, target * sqrt_weights) ================================================ FILE: inferno/extensions/criteria/regularized.py ================================================ import warnings import torch from torch import nn from . 
def collect_losses(module):
    """Gather the `_losses` dictionaries of ``module`` and all of its
    submodules into one dictionary, summing values that share a name.

    :param module: a Module to be searched for losses
    :return: dictionary of loss names to values
    """
    gathered = {}

    def _visit(m):
        # Modules without a `_losses` attribute contribute nothing.
        for name, value in getattr(m, '_losses', {}).items():
            gathered[name] = gathered[name] + value if name in gathered else value

    # `Module.apply` walks the module tree, calling `_visit` on every node.
    module.apply(_visit)
    return gathered
# Convenience wrappers for common losses

class RegularizedCrossEntropyLoss(RegularizedLoss):
    """`RegularizedLoss` wrapping :class:`torch.nn.CrossEntropyLoss`."""
    def __init__(self, *args, **kwargs):
        super(RegularizedCrossEntropyLoss, self).__init__(nn.CrossEntropyLoss, *args, **kwargs)


class RegularizedBCEWithLogitsLoss(RegularizedLoss):
    """`RegularizedLoss` wrapping :class:`torch.nn.BCEWithLogitsLoss`."""
    def __init__(self, *args, **kwargs):
        super(RegularizedBCEWithLogitsLoss, self).__init__(nn.BCEWithLogitsLoss, *args, **kwargs)


class RegularizedBCELoss(RegularizedLoss):
    """`RegularizedLoss` wrapping :class:`torch.nn.BCELoss`."""
    def __init__(self, *args, **kwargs):
        super(RegularizedBCELoss, self).__init__(nn.BCELoss, *args, **kwargs)


class RegularizedMSELoss(RegularizedLoss):
    """`RegularizedLoss` wrapping :class:`torch.nn.MSELoss`."""
    def __init__(self, *args, **kwargs):
        super(RegularizedMSELoss, self).__init__(nn.MSELoss, *args, **kwargs)


class RegularizedNLLLoss(RegularizedLoss):
    """`RegularizedLoss` wrapping :class:`torch.nn.NLLLoss`."""
    def __init__(self, *args, **kwargs):
        super(RegularizedNLLLoss, self).__init__(nn.NLLLoss, *args, **kwargs)
class SorensenDiceLoss(nn.Module):
    """
    Computes a loss scalar, which when minimized maximizes the Sorensen-Dice
    similarity between the input and the target. For both inputs and targets
    it must be the case that `input_or_target.size(1) = num_channels`.
    """
    def __init__(self, weight=None, channelwise=True, eps=1e-6):
        """
        Parameters
        ----------
        weight : torch.FloatTensor or torch.cuda.FloatTensor
            Class weights. Applies only if `channelwise = True`.
        channelwise : bool
            Whether to apply the loss channelwise and sum the results (True)
            or to apply it on all channels jointly (False).
        """
        super(SorensenDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps

    def forward(self, input, target):
        """
        input: torch.FloatTensor or torch.cuda.FloatTensor
        target: torch.FloatTensor or torch.cuda.FloatTensor

        Expected shape of the inputs: (batch_size, nb_channels, ...)
        """
        assert input.size() == target.size()
        if self.channelwise:
            # Reshape both tensors to (C, N), one row per channel.
            flat_input = flatten_samples(input)
            flat_target = flatten_samples(target)
            # Per-channel Dice: intersection over squared cardinalities.
            intersection = (flat_input * flat_target).sum(-1)
            cardinality = (flat_input * flat_input).sum(-1) \
                + (flat_target * flat_target).sum(-1)
            per_channel = -2 * (intersection / cardinality.clamp(min=self.eps))
            if self.weight is not None:
                # Older pytorch could yield shape (C, 1); squeeze to (C,).
                if per_channel.dim() == 2:
                    per_channel = per_channel.squeeze(1)
                assert self.weight.size() == per_channel.size()
                per_channel = self.weight * per_channel
            # Total loss is the (weighted) sum over channels.
            return per_channel.sum()
        # Joint variant: a single Dice score over all channels at once.
        intersection = (input * target).sum()
        cardinality = (input * input).sum() + (target * target).sum()
        return -2. * (intersection / cardinality.clamp(min=self.eps))
class GeneralizedDiceLoss(nn.Module):
    """
    Computes the scalar Generalized Dice Loss defined in
    https://arxiv.org/abs/1707.03237

    This version works for multiple classes and expects predictions for every
    class (e.g. softmax output) and one-hot targets for every class.
    """
    def __init__(self, weight=None, channelwise=False, eps=1e-6):
        super(GeneralizedDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps

    def forward(self, input, target):
        """
        input: torch.FloatTensor or torch.cuda.FloatTensor
        target: torch.FloatTensor or torch.cuda.FloatTensor

        Expected shape of the inputs:
            - if not channelwise: (batch_size, nb_classes, ...)
            - if channelwise: (batch_size, nb_channels, nb_classes, ...)
        """
        assert input.size() == target.size()
        if not self.channelwise:
            # Flatten to (nb_classes, N), N being the number of samples.
            input = flatten_samples(input)
            target = flatten_samples(target)
            # Class weights: inverse squared class volume, per the paper.
            class_sums = target.sum(-1)
            class_weights = 1. / (class_sums * class_sums).clamp(min=self.eps)
            numerator = ((input * target).sum(-1) * class_weights).sum()
            denominator = ((input + target).sum(-1) * class_weights).sum()
            return 1. - 2. * numerator / denominator.clamp(min=self.eps)

        def _flatten_keep_channels(tensor):
            # (B, C, K, ...) --> (C, K, B * prod(spatial))
            ndim = tensor.dim()
            assert ndim >= 3
            n_channels, n_classes = tensor.size(1), tensor.size(2)
            axes = [1, 2, 0] + list(range(3, ndim))
            permuted = tensor.permute(*axes).contiguous()
            return permuted.view(n_channels, n_classes, -1)

        input = _flatten_keep_channels(input)
        target = _flatten_keep_channels(target)
        # Per-channel class weights (inverse squared class volume).
        class_sums = target.sum(-1)
        class_weights = 1. / (class_sums * class_sums).clamp(min=self.eps)
        numerator = ((input * target).sum(-1) * class_weights).sum(-1)
        denominator = ((input + target).sum(-1) * class_weights).sum(-1)
        per_channel = 1. - 2. * numerator / denominator.clamp(min=self.eps)
        if self.weight is not None:
            if per_channel.dim() == 2:
                per_channel = per_channel.squeeze(1)
            assert self.weight.size() == per_channel.size(),\
                """`weight` should have shape (nb_channels, ), `target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
            # Apply channel weights:
            per_channel = self.weight * per_channel
        return per_channel.sum()
class Initializer(object):
    """
    Base class for all initializers.

    Calling an instance on a module applies `call_on_weight` / `call_on_bias`
    to the module's weight / bias data, but only for the layer types listed
    in `VALID_LAYERS`. Subclasses override `call_on_tensor` (applied to both),
    or `call_on_weight` / `call_on_bias` individually.
    """
    # TODO Support LSTMs and GRUs
    VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d',
                    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
                    'Linear', 'Bilinear', 'Embedding'}

    def __call__(self, module):
        module_class_name = module.__class__.__name__
        if module_class_name in self.VALID_LAYERS:
            # Apply to weight and bias. A NotImplementedError from either call
            # means the subclass doesn't initialize that tensor; ignore it.
            try:
                if getattr(module, 'weight', None) is not None:
                    self.call_on_weight(module.weight.data)
            except NotImplementedError:
                # Don't cry if it's not implemented
                pass
            try:
                # Robustness fix: layers constructed with `bias=False` carry
                # `module.bias is None`; the previous `hasattr` check passed
                # and then crashed on `None.data`.
                if getattr(module, 'bias', None) is not None:
                    self.call_on_bias(module.bias.data)
            except NotImplementedError:
                pass
        return module

    def call_on_bias(self, tensor):
        # By default, bias initialization falls through to `call_on_tensor`.
        return self.call_on_tensor(tensor)

    def call_on_weight(self, tensor):
        # By default, weight initialization falls through to `call_on_tensor`.
        return self.call_on_tensor(tensor)

    def call_on_tensor(self, tensor):
        raise NotImplementedError

    @classmethod
    def initializes_weight(cls):
        """Whether this class defines its own weight (or tensor) initializer."""
        return 'call_on_tensor' in cls.__dict__ or 'call_on_weight' in cls.__dict__

    @classmethod
    def initializes_bias(cls):
        """Whether this class defines its own bias (or tensor) initializer."""
        return 'call_on_tensor' in cls.__dict__ or 'call_on_bias' in cls.__dict__
class WeightInitFunction(Initializer):
    """Adapts a plain init function (e.g. from `torch.nn.init`) into an
    `Initializer` that applies it to weight tensors only."""
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(WeightInitFunction, self).__init__()
        assert callable(init_function)
        self.init_function = init_function
        # Extra positional/keyword arguments forwarded on every call.
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_weight(self, tensor):
        return self.init_function(tensor, *self.init_function_args,
                                  **self.init_function_kwargs)


class BiasInitFunction(Initializer):
    """Adapts a plain init function into an `Initializer` that applies it to
    bias tensors only."""
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(BiasInitFunction, self).__init__()
        assert callable(init_function)
        self.init_function = init_function
        # Extra positional/keyword arguments forwarded on every call.
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_bias(self, tensor):
        return self.init_function(tensor, *self.init_function_args,
                                  **self.init_function_kwargs)


class TensorInitFunction(Initializer):
    """Adapts a plain init function into an `Initializer` that applies it to
    both weight and bias tensors (via `call_on_tensor`)."""
    def __init__(self, init_function, *init_function_args, **init_function_kwargs):
        super(TensorInitFunction, self).__init__()
        assert callable(init_function)
        self.init_function = init_function
        # Extra positional/keyword arguments forwarded on every call.
        self.init_function_args = init_function_args
        self.init_function_kwargs = init_function_kwargs

    def call_on_tensor(self, tensor):
        return self.init_function(tensor, *self.init_function_args,
                                  **self.init_function_kwargs)
class NormalWeights(Initializer):
    """
    Initialize weights with random numbers drawn from the normal distribution
    at `mean` and `stddev`. If `sqrt_gain_over_fan_in` is given, the standard
    deviation is additionally scaled by sqrt(gain / fan_in) of the tensor.
    """
    def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):
        self.mean = mean
        self.stddev = stddev
        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in

    def compute_fan_in(self, tensor):
        # 2D (linear) weights: fan-in is the input dimension; for conv
        # weights it's in_channels times the kernel volume.
        return tensor.size(1) if tensor.dim() == 2 \
            else np.prod(list(tensor.size())[1:])

    def call_on_weight(self, tensor):
        sigma = self.stddev
        # Scale stddev by sqrt(gain / fan_in) if requested.
        if self.sqrt_gain_over_fan_in is not None:
            sigma = sigma * np.sqrt(self.sqrt_gain_over_fan_in
                                    / self.compute_fan_in(tensor))
        tensor.normal_(self.mean, sigma)
class SELU(nn.Module):
    """Scaled Exponential Linear Unit activation (Klambauer et al., 2017)."""

    def forward(self, input):
        return self.selu(input)

    @staticmethod
    def selu(x):
        # Fixed-point constants from the SELU paper.
        alpha_const = 1.6732632423543772848170429916717
        scale_const = 1.0507009873554804934193349852946
        # Negative half uses alpha * elu(x); positive half is identity.
        negative_part = alpha_const * F.elu(x)
        # noinspection PyTypeChecker
        return scale_const * where(x >= 0, x, negative_part)
class ConvActivation(nn.Module):
    """Convolutional layer with 'SAME' padding by default followed by an activation.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int or None or 'auto'
        Number of output channels. For depthwise convolutions this may be
        None or 'auto', in which case it defaults to `in_channels`.
    kernel_size : int or tuple
        Convolution kernel size (odd sizes required for 'SAME' padding).
    dim : {1, 2, 3}
        Spatial dimensionality of the convolution.
    activation : str or torch.nn.Module or None
        Activation applied after the convolution: the name of a `torch.nn`
        class, a module instance, or None for no activation.
    stride, dilation, groups, bias :
        Forwarded to the underlying `torch.nn.Conv*d`.
    depthwise : bool
        If True, build a depthwise convolution (groups == in_channels).
    deconv : bool
        If True, use a transposed convolution.
    initialization : Initializer or None
        Optional initializer applied to the conv module.
    valid_conv : bool
        If True, use 'VALID' (no) padding instead of 'SAME' padding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dim,
        activation,
        stride=1,
        dilation=1,
        groups=None,
        depthwise=False,
        bias=True,
        deconv=False,
        initialization=None,
        valid_conv=False,
    ):
        super(ConvActivation, self).__init__()
        # Validate dim
        assert_(
            dim in [1, 2, 3],
            "`dim` must be one of [1, 2, 3], got {}.".format(dim),
            ShapeError,
        )
        self.dim = dim
        # Check if depthwise
        if depthwise:
            # We know that in_channels == out_channels, but we also want a consistent API.
            # As a compromise, we allow that out_channels be None or 'auto'.
            # Bug fix: this previously assigned the undefined name `out_channel`,
            # raising a NameError whenever an explicit out_channels was passed
            # together with depthwise=True.
            out_channels = in_channels if out_channels in [None, "auto"] else out_channels
            assert_(
                in_channels == out_channels,
                "For depthwise convolutions, number of input channels (given: {}) "
                "must equal the number of output channels (given {}).".format(
                    in_channels, out_channels
                ),
                ValueError,
            )
            assert_(
                groups is None or groups == in_channels,
                "For depthwise convolutions, groups (given: {}) must "
                "equal the number of channels (given: {}).".format(groups, in_channels),
            )
            groups = in_channels
        else:
            groups = 1 if groups is None else groups
        self.depthwise = depthwise
        if valid_conv:
            # 'VALID' padding: no padding at all.
            self.conv = getattr(nn, "Conv{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        elif not deconv:
            # Get 'SAME' padding from the kernel size and dilation.
            padding = self.get_padding(kernel_size, dilation)
            self.conv = getattr(nn, "Conv{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        else:
            self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        if initialization is None:
            pass
        elif isinstance(initialization, Initializer):
            self.conv.apply(initialization)
        else:
            raise NotImplementedError
        if isinstance(activation, str):
            self.activation = getattr(nn, activation)()
        elif isinstance(activation, nn.Module):
            self.activation = activation
        elif activation is None:
            self.activation = None
        else:
            raise NotImplementedError

    def forward(self, input):
        conved = self.conv(input)
        if self.activation is not None:
            activated = self.activation(conved)
        else:
            # No activation
            activated = conved
        return activated

    def _pair_or_triplet(self, object_):
        # Broadcast a scalar to one entry per spatial dimension.
        if isinstance(object_, (list, tuple)):
            assert len(object_) == self.dim
            return object_
        else:
            object_ = [object_] * self.dim
            return object_

    def _get_padding(self, _kernel_size, _dilation):
        assert isinstance(_kernel_size, int)
        assert isinstance(_dilation, int)
        # 'SAME' padding requires an odd kernel size.
        assert _kernel_size % 2 == 1
        return ((_kernel_size - 1) // 2) * _dilation

    def get_padding(self, kernel_size, dilation):
        """Per-dimension 'SAME' padding for the given kernel size and dilation."""
        kernel_size = self._pair_or_triplet(kernel_size)
        dilation = self._pair_or_triplet(dilation)
        padding = [
            self._get_padding(_kernel_size, _dilation)
            for _kernel_size, _dilation in zip(kernel_size, dilation)
        ]
        return tuple(padding)
def _register_bnr_conv_cls(conv_name, fix=None, default=None):
    """Register BNReLU-style partial classes for `conv_name`: one
    dimension-agnostic ``BNReLU<conv_name>ND`` class plus one concrete class
    per spatial dimension (1D/2D/3D).

    Bug fix: the ND variant was previously registered — and appended to
    ``__all__`` — three times, inside a `for dim` loop whose `dim` was unused.
    """
    if fix is None:
        fix = {}
    if default is None:
        default = {}
    # Dimension-agnostic variant: register exactly once.
    cls_name = "BNReLU{}ND".format(conv_name)
    __all__.append(cls_name)
    register_partial_cls_here(BNReLUConvBaseND, cls_name, fix=fix, default=default)
    # One concrete class per spatial dimension.
    for dim in [1, 2, 3]:
        cls_name = "BNReLU{}{}D".format(conv_name, dim)
        __all__.append(cls_name)
        register_partial_cls_here(BNReLUConvBaseND, cls_name,
                                  fix={**fix, 'dim': dim}, default=default)
class GlobalConv2D(nn.Module):
    """From https://arxiv.org/pdf/1703.02719.pdf
    Main idea: we can have a bigger kernel size computationally acceptable
    if we separate 2D-conv in 2 1D-convs.

    Two parallel branches are built — (k, 1) then (1, k), and (1, k) then
    (k, 1) — and their outputs are summed element-wise.

    Parameters
    ----------
    in_channels, out_channels : int
        Channel counts of the block.
    kernel_size : int or (int, int)
        Size `k` of the decomposed k x k kernel.
    local_conv_type : callable
        Conv layer factory (e.g. `torch.nn.Conv2d`) used for all four convs.
    activation : torch.nn.Module or None
        Optional activation applied to the summed output.
    use_BN : bool
        Whether to append a BatchNorm2d after the (activated) output.
    kwargs :
        Forwarded to every `local_conv_type` call.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        local_conv_type,
        activation=None,
        use_BN=False,
        **kwargs
    ):
        super(GlobalConv2D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        assert isinstance(kernel_size, (int, list, tuple))
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 2
        self.kwargs = kwargs
        # Branch 1: (k, 1) followed by (1, k).
        self.conv1a = local_conv_type(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(kernel_size[0], 1),
            **kwargs
        )
        self.conv1b = local_conv_type(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=(1, kernel_size[1]),
            **kwargs
        )
        # Branch 2: (1, k) followed by (k, 1).
        self.conv2a = local_conv_type(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(1, kernel_size[1]),
            **kwargs
        )
        self.conv2b = local_conv_type(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=(kernel_size[0], 1),
            **kwargs
        )
        if use_BN:
            self.batchnorm = nn.BatchNorm2d(self.out_channels)
        else:
            self.batchnorm = None
        self.activation = activation

    def forward(self, input_):
        out1 = self.conv1a(input_)
        out1 = self.conv1b(out1)
        out2 = self.conv2a(input_)
        out2 = self.conv2b(out2)
        # Bug fix: `out1.add(1, out2)` used the legacy Tensor.add(alpha, other)
        # overload, which modern PyTorch rejects; with alpha == 1 this is a
        # plain element-wise sum.
        out = out1 + out2
        if self.activation is not None:
            out = self.activation(out)
        if self.batchnorm is not None:
            out = self.batchnorm(out)
        return out
class ResidualBlock(nn.Module):
    """Wraps a list of layers in a residual (skip) connection: the output is
    `layers(input) + skip`, where `skip` is `resample(input)` when a resampling
    module is given, or `input` itself otherwise."""

    def __init__(self, layers, resample=None):
        super(ResidualBlock, self).__init__()
        assert pyu.is_listlike(layers)
        self.layers = nn.Sequential(*layers)
        self.resample = resample

    def forward(self, input):
        branch = self.layers(input)
        # The skip path must match the branch's shape; `resample` handles that
        # when the layers change resolution.
        shortcut = input if self.resample is None else self.resample(input)
        return branch + shortcut
class DeviceTransfer(nn.Module):
    """Layer to transfer variables to a specified device."""

    def __init__(self, target_device, device_ordinal=None, asynchronous=False):
        """
        Parameters
        ----------
        target_device : {'cpu', 'cuda'}
            Device to transfer to.
        device_ordinal : int
            Device ordinal if target_device == 'cuda'.
        asynchronous : bool
            Whether to use asynchronous transfers.
        """
        super(DeviceTransfer, self).__init__()
        # Validate arguments
        assert_(target_device in ['cpu', 'cuda'],
                "Target device must either be 'cpu' or 'cuda'.",
                DeviceError)
        if target_device == 'cpu':
            assert_(device_ordinal is None,
                    "'device_ordinal' must be None if target_device is 'cpu'.",
                    DeviceError)
        self.target_device = target_device
        self.device_ordinal = device_ordinal
        # Bug fix: `asynchronous` was previously never stored, so the CUDA
        # branch of `forward` crashed with an AttributeError on
        # `self.asynchronous`.
        self.asynchronous = asynchronous

    def forward(self, *inputs):
        if self.target_device == 'cuda':
            transferred = tuple(input_.cuda(device=self.device_ordinal,
                                            non_blocking=self.asynchronous)
                                for input_ in inputs)
        elif self.target_device == 'cpu':
            transferred = tuple(input_.cpu() for input_ in inputs)
        else:
            raise NotImplementedError
        # Unwrap single-element tuples for convenience.
        return from_iterable(transferred)
device_ordinal : int Ordinal of the GPU device if `target_device = 'cuda'`. asynchronous : bool Whether to use asynchronous transfers. """ super(OnDevice, self).__init__() # Validate arguments assert_(target_device in ['cpu', 'cuda'], "Target device must either be 'cpu' or 'cuda'.", DeviceError) if target_device == 'cpu': assert_(device_ordinal is None, "'device_ordinal' must be None if target_device is 'cpu'.", DeviceError) self.target_device = target_device self.device_ordinal = device_ordinal self.asynchronous = asynchronous # This is a no-op if module is already in the right device self.device_transfer = DeviceTransfer(self.target_device, device_ordinal=self.device_ordinal, asynchronous=self.asynchronous) self.module = self.transfer_module(module) def transfer_module(self, module): if self.target_device == 'cuda': return module.cuda(device_id=self.device_ordinal) elif self.target_device == 'cpu': return module.cpu() else: raise NotImplementedError def forward(self, *inputs): # Transfer inputs (no-op if they're already on the right device) transferred = to_iterable(self.device_transfer(*inputs)) output = self.module(*transferred) return output ================================================ FILE: inferno/extensions/layers/identity.py ================================================ import torch.nn as nn __all__ = ['identity'] _all = __all__ class Identity(nn.Module): def __init__(self): super(Identity, self).__init__() def forward(self, x): return x ================================================ FILE: inferno/extensions/layers/normalization.py ================================================ import torch.nn as nn class BatchNormND(nn.Module): def __init__(self, dim, num_features, eps=1e-5, momentum=0.1, affine=True,track_running_stats=True): super(BatchNormND, self).__init__() assert dim in [1, 2, 3] self.bn = getattr(nn, 'BatchNorm{}d'.format(dim))(num_features=num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 
class View(nn.Module):
    """Reshape the input tensor to ``as_shape``.

    ``as_shape`` is a sequence whose entries are either ints (passed to
    ``Tensor.view``) or the literal string ``'x'``, meaning "keep the
    input's size along this dimension". All ``'x'`` placeholders must
    come before the first integer entry.
    """

    def __init__(self, as_shape):
        super(View, self).__init__()
        self.as_shape = self.validate_as_shape(as_shape)

    def validate_as_shape(self, as_shape):
        # Every entry must be an int or the placeholder 'x'.
        assert all(isinstance(entry, int) or entry == 'x' for entry in as_shape)
        # The placeholders must form a prefix: once an integer shows up,
        # only integers may follow.
        int_positions = [position for position, entry in enumerate(as_shape)
                         if isinstance(entry, int)]
        if int_positions:
            tail = as_shape[int_positions[0]:]
            assert all(isinstance(entry, int) for entry in tail)
        return as_shape

    def forward(self, input):
        input_shape = list(input.size())
        # Substitute each 'x' with the corresponding input dimension.
        target_shape = [entry if isinstance(entry, int) else input_shape[position]
                        for position, entry in enumerate(self.as_shape)]
        return input.view(*target_shape)
class As2D(nn.Module):
    """Reshape batches of 3D volumes or matrices into batches of 2D images.

    Parameters
    ----------
    z_as_channel : bool
        If True (default), fold the z (depth) axis of a 5D input into the
        channel axis of the resulting 4D tensor. If False, the z axis must
        be singleton and is simply dropped.
    """
    def __init__(self, z_as_channel=True):
        super(As2D, self).__init__()
        self.z_as_channel = z_as_channel

    def forward(self, input):
        if input.dim() == 5:
            b, c, _0, _1, _2 = list(input.size())
            if not self.z_as_channel:
                # Without channel folding we can only drop a singleton z axis.
                assert _0 == 1
            # Fold z into the channel axis (a no-op fold when _0 == 1).
            return input.view(b, c * _0, _1, _2)
        elif input.dim() == 4:
            # Already a batch of 2D images - nothing to do here.
            return input
        elif input.dim() == 2:
            # Matrix to image batch: append singleton spatial dimensions.
            b, c = list(input.size())
            return input.view(b, c, 1, 1)
        else:
            # Fix: previously this fell off the end and silently returned
            # None for unsupported dimensionalities (e.g. a 3D input).
            # Fail loudly instead, consistent with As3D.
            raise NotImplementedError
.format(self.POOL_MODE_MAPPING.keys(), pool_mode), ValueError) self.pool_mode = self.POOL_MODE_MAPPING.get(pool_mode) self.dim = dim def forward(self, *inputs): dim = inputs[0].dim() assert_(dim in [4, 5], 'Input tensors must either be 4 or 5 ' 'dimensional, but inputs[0] is {}D.'.format(dim), ShapeError) # Get resize function spatial_dim = {4: 2, 5: 3}[dim] resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode, spatial_dim)) target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim) # Do the resizing resized_inputs = [] for input_num, input in enumerate(inputs): # Make sure the dim checks out assert_(input.dim() == dim, "Expected inputs[{}] to be a {}D tensor, got a {}D " "tensor instead.".format(input_num, dim, input.dim()), ShapeError) resized_inputs.append(resize_function(input, target_size)) # Concatenate along the channel axis if len(resized_inputs) > 1: concatenated = torch.cat(tuple(resized_inputs), self.dim) else: concatenated = resized_inputs[0] # Done return concatenated class Cat(Concatenate): """An alias for `Concatenate`. 
class Sum(nn.Module):
    """Elementwise sum of all input tensors."""
    def forward(self, *inputs):
        # Stack the inputs along a fresh leading axis and reduce over it;
        # stacking also enforces that all inputs share the same shape.
        stacked = torch.stack(inputs)
        return stacked.sum(dim=0)
class RemoveSingletonDimension(nn.Module):
    """Remove a singleton dimension at position ``dim``.

    Parameters
    ----------
    dim : int
        Index of the dimension to remove (default 1, the channel axis).

    Raises
    ------
    RuntimeError
        If the input's size at ``dim`` is not 1.
    """
    def __init__(self, dim=1):
        super(RemoveSingletonDimension, self).__init__()
        # Fix: this was hard-coded to `self.dim = 1`, silently ignoring
        # the `dim` constructor argument.
        self.dim = dim

    def forward(self, x):
        size = list(x.size())
        if size[self.dim] != 1:
            raise RuntimeError("RemoveSingletonDimension expects a single channel at dim %d, shape=%s"%(self.dim,str(size)))
        # squeeze(dim) drops exactly this axis; equivalent to the old
        # hand-built slicing, but avoids list-based advanced indexing,
        # which is deprecated in modern PyTorch.
        return x.squeeze(self.dim)
class ArandScore(Metric):
    """Arand Score, as defined in [1].

    Parameters
    ----------
    average_slices : bool
        If True and the input volume is 3D, compute the score per 2D slice
        and average over the slices, instead of scoring the volume at once.

    References
    ----------
    [1]: http://journal.frontiersin.org/article/10.3389/fnana.2015.00142/full#h3
    """
    def __init__(self, average_slices=True):
        self.average_slices = average_slices

    # compute the arand score for a prediction target pair
    def _arand_for_tensor(self, prediction, target):
        """Score a single (channel-stripped) prediction/target pair.

        `adapted_rand` returns None when either input is all zeros;
        such results are treated as invalid and scored as 0 (the worst
        possible score) with a warning.
        """
        # check if we need to average over slices
        average_slices = self.average_slices and prediction.ndim == 3
        score_is_invalid = False
        # average the rand score over 3d slices
        if average_slices:
            # average the arand values over the 3d slices
            evaluation_values = [adapted_rand(pred, targ)
                                 for pred, targ in zip(prediction, target)]
            # check if the score is invalid
            if all(ev_val is None for ev_val in evaluation_values):
                score_is_invalid = True
                score = 0
            else:
                # adapted_rand returns (f_score, precision, recall);
                # only the f-score (index 0) contributes to the mean.
                score = np.mean([eval_val[0] for eval_val in evaluation_values
                                 if eval_val is not None])
        # compute rand score on whole image / volume
        else:
            score = adapted_rand(prediction, target)
            # check if the score is invalid
            if score is None:
                score_is_invalid = True
                score = 0
            else:
                score = score[0]
        if score_is_invalid:
            logger = logging.getLogger(__name__)
            logger.warning("All slices were invalid, returning worst possible score")
        return score

    def forward(self, prediction, target):
        """Average Arand score over the batch.

        Expects `prediction` and `target` tensors of identical shape with
        a batch axis and a singleton channel axis, i.e. (N, 1, H, W) or
        (N, 1, D, H, W).
        """
        assert(prediction.shape == target.shape), "%s, %s" % (str(prediction.shape), str(target.shape))
        assert prediction.shape[1] == 1, "Expect singleton channel axis"
        prediction = prediction.cpu().numpy()
        target = target.cpu().numpy()
        ndim = prediction.ndim
        assert ndim in (4, 5), "Expect 2 or 3d input with additional batch and channel axis"
        # return the average arand error over the batches
        # (pred[0] / targ[0] strips the singleton channel axis)
        return np.mean([self._arand_for_tensor(pred[0], targ[0])
                        for pred, targ in zip(prediction, target)])
def adapted_rand(seg, gt):
    r"""Compute Adapted Rand error as defined by the SNEMI3D contest [1]

    Formula is given as 1 - the maximal F-score of the Rand index
    (excluding the zero component of the original labels). Adapted
    from the SNEMI3D MATLAB script, hence the strange style.

    Parameters
    ----------
    seg : np.ndarray
        the segmentation to score, where each value is the label at that point
    gt : np.ndarray, same shape as seg
        the groundtruth to score against, where each value is a label

    Returns
    -------
    are : float
        The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$,
        where $p$ and $r$ are the precision and recall described below.
    prec : float, optional
        The adapted Rand precision.
    rec : float, optional
        The adapted Rand recall.

    Returns None instead if either input is entirely zero.

    References
    ----------
    [1]: http://brainiac2.mit.edu/SNEMI3D/evaluation
    """
    # NOTE: the docstring is now a raw string; previously `\frac` contained
    # the `\f` form-feed escape and rendered garbled.
    assert seg.shape == gt.shape, "%s, %s" % (str(seg.shape), str(gt.shape))
    logger = logging.getLogger(__name__)
    if np.any(seg == 0):
        logger.debug("Zeros in segmentation, treating as background.")
    if np.any(gt == 0):
        logger.debug("Zeros in ground truth, 0's will be ignored.")
    seg_zeros = np.all(seg == 0)
    gt_zeros = np.all(gt == 0)
    # return None if either gt or segmentation are all zeros
    if seg_zeros or gt_zeros:
        # Fix: this message used to be logged unconditionally, before the
        # check - it now fires only when the condition actually holds.
        logger.debug("Either segmentation or groundtruth are all zeros, returning None.")
        return None
    # segA is truth, segB is query
    segA = np.ravel(gt)
    segB = np.ravel(seg)
    # mask to foreground in A
    mask = (segA > 0)
    segA = segA[mask]
    segB = segB[mask]
    # number of nonzero pixels in original segA
    n = segA.size
    n_labels_A = int(np.amax(segA)) + 1
    n_labels_B = int(np.amax(segB)) + 1
    ones_data = np.ones(n)
    # Joint label-contingency table: p_ij[a, b] counts voxels labeled a in
    # gt and b in seg (over the foreground-of-gt mask).
    p_ij = sparse.csr_matrix((ones_data, (segA.ravel(), segB.ravel())),
                             shape=(n_labels_A, n_labels_B),
                             dtype=np.uint64)
    # In the paper where adapted rand is proposed, they treat each background
    # pixel in segB as a different value (i.e., unique label for each pixel).
    # To do this, we sum them differently than others
    # ind (label_gt, label_seg), so ignore 0 seg labels
    B_nonzero = p_ij[:, 1:]
    B_zero = p_ij[:, 0]
    # this is a count
    num_B_zero = B_zero.sum()
    # sum of the joint distribution
    # separate sum of B>0 and B=0 parts
    sum_p_ij = (B_nonzero).power(2).sum() + num_B_zero
    # these are marginal probabilities
    # sum over all seg labels overlapping one gt label (except 0 labels)
    a_i = p_ij.sum(1)
    b_i = B_nonzero.sum(0)
    sum_a = np.power(a_i, 2).sum()
    sum_b = np.power(b_i, 2).sum() + num_B_zero
    precision = float(sum_p_ij) / sum_b
    recall = float(sum_p_ij) / sum_a
    f_score = 2.0 * precision * recall / (precision + recall)
    return f_score, precision, recall
class IOU(Metric):
    """Intersection over Union (Jaccard index) metric.

    Computes a soft, classwise IOU between a class-probability prediction
    and either an integer-label or a one-hot target, then averages over
    classes (optionally excluding one class).

    Parameters
    ----------
    ignore_class : int or None
        Class index to exclude from the mean over classwise IOUs;
        -1 selects the last class. None (default) keeps all classes.
    sharpen_prediction : bool
        If True, the prediction is replaced by a one-hot vector at its
        argmax before computing the IOU.
    eps : float
        Lower clamp on the denominator for numerical stability.
    """
    def __init__(self, ignore_class=None, sharpen_prediction=False, eps=1e-6):
        super(IOU, self).__init__()
        self.eps = eps
        self.ignore_class = ignore_class
        self.sharpen_prediction = sharpen_prediction

    def forward(self, prediction, target):
        # Assume that is one of:
        # prediction.shape = (N, C, H, W)
        # prediction.shape = (N, C, D, H, W)
        # prediction.shape = (N, C)
        # The corresponding target shapes are either:
        # target.shape = (N, H, W)
        # target.shape = (N, D, H, W)
        # target.shape = (N,)
        # Or:
        # target.shape = (N, C, H, W)
        # target.shape = (N, C, D, H, W)
        # target.shape = (N, C)
        # First, reshape prediction to (C, -1)
        flattened_prediction = flatten_samples(prediction)
        # Take measurements
        num_classes, num_samples = flattened_prediction.size()
        # We need to figure out if the target is a int label tensor or a onehot tensor.
        # The former always has one dimension less, so
        if target.dim() == (prediction.dim() - 1):
            # Labels, we need to go one hot
            # Make sure it's a label
            assert_(is_label_tensor(target),
                    "Target must be a label tensor (of dtype long) if it has one "
                    "dimension less than the prediction.",
                    DTypeError)
            # Reshape target to (1, -1) for it to work with scatter
            flattened_target = target.view(1, -1)
            # Convert target to onehot with shape (C, -1)
            # Make sure the target is consistent
            assert_(target.max() < num_classes)
            onehot_targets = flattened_prediction \
                .new(num_classes, num_samples) \
                .zero_() \
                .scatter_(0, flattened_target, 1)
        elif target.dim() == prediction.dim():
            # Onehot, nothing to do except flatten
            onehot_targets = flatten_samples(target)
        else:
            raise ShapeError("Target must have the same number of dimensions as the "
                             "prediction, or one less. Got target.dim() = {} but "
                             "prediction.dim() = {}.".format(target.dim(), prediction.dim()))
        # Cast onehot_targets to float if required (this is a no-op if it's already float)
        onehot_targets = onehot_targets.float()
        # Sharpen prediction if required to. Sharpening in this sense means to replace
        # the max predicted probability with 1.
        if self.sharpen_prediction:
            _, predicted_classes = torch.max(flattened_prediction, 0)
            # Case for pytorch 0.2, where predicted_classes is (N,) instead of (1, N)
            if predicted_classes.dim() == 1:
                predicted_classes = predicted_classes.view(1, -1)
            # Scatter
            flattened_prediction = flattened_prediction\
                .new(num_classes, num_samples).zero_().scatter_(0, predicted_classes, 1)
        # Now to compute the IOU = (a * b).sum()/(a**2 + b**2 - a * b).sum()
        # We sum over all samples to obtain a classwise iou
        numerator = (flattened_prediction * onehot_targets).sum(-1)
        # NOTE(review): sub_/pow_/clamp_ mutate flattened_prediction in
        # place; if flatten_samples returns a view, this also modifies the
        # caller's `prediction` tensor - confirm this is intended.
        denominator = \
            flattened_prediction.sub_(onehot_targets).pow_(2).clamp_(min=self.eps).sum(-1) + \
            numerator
        classwise_iou = numerator.div_(denominator)
        # If we're ignoring a class, don't count its contribution to the mean
        if self.ignore_class is not None:
            # -1 is shorthand for the last class
            ignore_class = self.ignore_class \
                if self.ignore_class != -1 else onehot_targets.size(0) - 1
            assert_(ignore_class < onehot_targets.size(0),
                    "`ignore_class` = {} must be at least one less than the number "
                    "of classes = {}.".format(ignore_class, onehot_targets.size(0)),
                    ValueError)
            num_classes = onehot_targets.size(0)
            # Index list of every class except the ignored one
            dont_ignore_class = list(range(num_classes))
            dont_ignore_class.pop(ignore_class)
            if classwise_iou.is_cuda:
                dont_ignore_class = \
                    torch.LongTensor(dont_ignore_class).cuda(classwise_iou.get_device())
            else:
                dont_ignore_class = torch.LongTensor(dont_ignore_class)
            iou = classwise_iou[dont_ignore_class].mean()
        else:
            iou = classwise_iou.mean()
        return iou
class VoiScore(Metric):
    """Score based on the variation of information (VI), according to [1].

    For each element of the batch the VI split and merge errors are
    computed and summed; the result is averaged over the batch.

    References
    ----------
    [1] Meila, M. (2007). Comparing clusterings - an information based
    distance. Journal of Multivariate Analysis 98, 873-895.
    """
    def forward(self, prediction, target):
        assert(len(prediction) == len(target))
        seg_batch = prediction.cpu().numpy()
        gt_batch = target.cpu().numpy()
        # voi() returns (split, merge); their sum is the full VI value.
        per_sample = [sum(voi(seg, gt)) for seg, gt in zip(seg_batch, gt_batch)]
        return np.mean(per_sample)
def split_vi(x, y=None, ignore_x=[0], ignore_y=[0]):
    """Return the symmetric conditional entropies associated with the VI.

    The variation of information is defined as VI(X,Y) = H(X|Y) + H(Y|X).
    If Y is the ground-truth segmentation, then H(Y|X) can be interpreted
    as the amount of under-segmentation of Y and H(X|Y) is then the amount
    of over-segmentation. In other words, a perfect over-segmentation
    will have H(Y|X)=0 and a perfect under-segmentation will have H(X|Y)=0.

    If y is None, x is assumed to be a contingency table.

    Parameters
    ----------
    x : np.ndarray
        Label field (int type) or contingency table (float). `x` is
        interpreted as a contingency table (summing to 1.0) if and only
        if `y` is not provided.
    y : np.ndarray of int, same shape as x, optional
        A label field to compare to `x`.
    ignore_x, ignore_y : list of int, optional
        Any points having a label in this list are ignored in the
        evaluation. Ignore 0-labeled points by default.

    Returns
    -------
    sv : np.ndarray of float, shape (2,)
        The conditional entropies of Y|X and X|Y.

    See Also
    --------
    vi
    """
    tables = vi_tables(x, y, ignore_x, ignore_y)
    # vi_tables returns [pxy, px, py, hxgy, hygx, lpygx, lpxgy];
    # we only need the per-segment conditional entropies.
    hxgy, hygx = tables[3], tables[4]
    # false merges, false splits
    return np.array([hygx.sum(), hxgy.sum()])
def contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True):
    """Return the contingency table for all regions in matched segmentations.

    Parameters
    ----------
    seg : np.ndarray, int type, arbitrary shape
        A candidate segmentation.
    gt : np.ndarray, int type, same shape as `seg`
        The ground truth segmentation.
    ignore_seg : list of int, optional
        Values to ignore in `seg`. Voxels in `seg` having a value in this list
        will not contribute to the contingency table. (default: [0])
    ignore_gt : list of int, optional
        Values to ignore in `gt`. Voxels in `gt` having a value in this list
        will not contribute to the contingency table. (default: [0])
    norm : bool, optional
        Whether to normalize the table so that it sums to 1.

    Returns
    -------
    cont : scipy.sparse.csc_matrix
        A contingency table. `cont[i, j]` will equal the number of voxels
        labeled `i` in `seg` and `j` in `gt`. (Or the proportion of such
        voxels if `norm=True`.)
    """
    segr = seg.ravel()
    gtr = gt.ravel()
    # Fix: `np.bool` was removed in NumPy 1.24; use the builtin bool dtype.
    ignored = np.zeros(segr.shape, bool)
    data = np.ones(len(gtr))
    for i in ignore_seg:
        ignored[segr == i] = True
    for j in ignore_gt:
        ignored[gtr == j] = True
    # Ignored voxels contribute zero weight to the table.
    data[ignored] = 0
    cont = sparse.coo_matrix((data, (segr, gtr))).tocsc()
    if norm:
        cont /= float(cont.sum())
    return cont
def xlogx(x, out=None, in_place=False):
    """Compute x * log_2(x) elementwise, defining 0 * log_2(0) = 0.

    Parameters
    ----------
    x : np.ndarray or scipy.sparse.csc_matrix or csr_matrix
        The input array.
    out : same type as x (optional)
        If provided, use this array/matrix for the result.
    in_place : bool (optional, default False)
        Operate directly on x.

    Returns
    -------
    y : same type as x
        Result of x * log_2(x).
    """
    # Choose the destination buffer.
    if in_place:
        destination = x
    else:
        destination = x.copy() if out is None else out
    # NOTE(review): when `out` is given, its *current* contents (not `x`)
    # are transformed in place - callers appear to be expected to pre-fill
    # `out`; confirm this is the intended contract (same as upstream gala).
    # For sparse matrices, operate directly on the stored data array.
    if type(destination) in [sparse.csc_matrix, sparse.csr_matrix]:
        values = destination.data
    else:
        values = destination
    # Only touch nonzero entries, which realizes 0 * log2(0) = 0.
    nonzero = values.nonzero()
    values[nonzero] *= np.log2(values[nonzero])
    return destination
self.in_channels = int(in_channels) self.out_channels = int(out_channels) self.size = int(size) self.activated = bool(activated) self.force_skip_op = bool(force_skip_op) self.dim = int(dim) if self.in_channels != self.out_channels or self.force_skip_op: self.activated_skip_op = self.activated_skip_op_factory(in_channels=self.in_channels, out_channels=self.out_channels) conv_ops = [] activation_ops = [] for i in range(self.size): # the convolutions if i == 0: op = self.nonactivated_conv_op_factory(in_channels=self.out_channels, out_channels=self.out_channels, index=i) else: op = self.nonactivated_conv_op_factory(in_channels=self.out_channels, out_channels=self.out_channels, index=i) conv_ops.append(op) # the activations if i < self.size or self.activated: activation_ops.append(self.activation_op_factory(index=i)) self.conv_ops = nn.ModuleList(conv_ops) self.activation_ops = nn.ModuleList(activation_ops) def activated_skip_op_factory(self, in_channels, out_channels): raise NotImplementedError("activated_skip_op_factory need to be implemented by deriving class") def nonactivated_conv_op_factory(self, in_channels, out_channels, index): raise NotImplementedError("conv_op_factory need to be implemented by deriving class") def activation_op_factory(self, index): return nn.ReLU() def forward(self, input): if input.size(1) != self.in_channels: raise RuntimeError("wrong number of channels: expected %d, got %d"% (self.in_channels, input.size(1))) if input.dim() != self.dim + 2: raise RuntimeError("wrong number of dim: expected %d, got %d"% (self.dim+2, input.dim())) if self.in_channels != self.out_channels or self.force_skip_op: skip_res = self.activated_skip_op(input) else: skip_res = input assert skip_res.size(1) == self.out_channels res = skip_res for i in range(self.size): res = self.conv_ops[i](res) assert res.size(1) == self.out_channels if i + 1 < self.size: res = self.activation_ops[i](res) non_activated = skip_res + res if self.activated: return 
class _ResBlock(_ResBlockBase):
    """Concrete residual block built from `ConvActivation` convolutions with
    optional batch normalization.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dim : int
        Spatial dimensionality (2 or 3).
    size : int
        Number of convolutions in the residual branch (default 2).
    activated : bool
        Whether to activate the block's final (summed) output.
    activation : str or nn.Module
        Name of a `torch.nn` activation class, or an activation module.
    batchnorm : bool
        Insert `BatchNorm{dim}d` after each convolution.
    force_skip_op : bool
        Build the 1x1 skip convolution even when channel counts match.
    conv_kwargs : dict or None
        Keyword arguments forwarded to `ConvActivation` for the branch
        convolutions; `activation` is always forced to None here because
        activations are managed by the block itself.
    """

    def __init__(self, in_channels, out_channels, dim, size=2, activated=True,
                 activation='ReLU', batchnorm=True, force_skip_op=False,
                 conv_kwargs=None):
        # Trick to store an nn.Module before the call of super().__init__():
        # wrapping it in a plain list prevents premature submodule
        # registration. NOTE: the same activation instance is reused at every
        # activation site — fine for stateless activations like ReLU;
        # parametric activations (e.g. PReLU) would share parameters.
        if isinstance(activation, str):
            self.activation_op = [getattr(torch.nn, activation)()]
        elif isinstance(activation, nn.Module):
            self.activation_op = [activation]
        else:
            # fixed typo in the error message ("striong" -> "string")
            raise RuntimeError("activation must be a string or a torch.nn.Module")

        # keywords for the branch convolutions
        if conv_kwargs is None:
            conv_kwargs = dict(
                kernel_size=3, dim=dim, activation=None, stride=1, dilation=1,
                groups=None, depthwise=False, bias=True, deconv=False,
                initialization=None
            )
        elif isinstance(conv_kwargs, dict):
            # activations are handled by the block itself
            conv_kwargs['activation'] = None
        else:
            raise RuntimeError("conv_kwargs must be either None or a dict")

        self.conv_kwargs = conv_kwargs
        self.dim = dim
        self.batchnorm = batchnorm
        # fixed kwargs for the 1x1 skip convolution
        self.conv_1x1_kwargs = dict(kernel_size=1, dim=dim, activation=None,
                                    stride=1, dilation=1, groups=None,
                                    depthwise=False, bias=True, deconv=False,
                                    initialization=None)
        super(_ResBlock, self).__init__(in_channels=in_channels,
                                        out_channels=out_channels,
                                        dim=dim, size=size,
                                        force_skip_op=force_skip_op,
                                        activated=activated)

    def activated_skip_op_factory(self, in_channels, out_channels):
        """1x1 conv (+ optional batchnorm) followed by the activation."""
        conv_op = ConvActivation(in_channels=in_channels,
                                 out_channels=out_channels,
                                 **self.conv_1x1_kwargs)
        if self.batchnorm:
            batchnorm_op = self.batchnorm_op_factory(in_channels=out_channels)
            return torch.nn.Sequential(conv_op, batchnorm_op, self.activation_op[0])
        else:
            return torch.nn.Sequential(conv_op, self.activation_op[0])

    def nonactivated_conv_op_factory(self, in_channels, out_channels, index):
        """Branch convolution (+ optional batchnorm), no activation."""
        conv_op = ConvActivation(in_channels=in_channels,
                                 out_channels=out_channels,
                                 **self.conv_kwargs)
        if self.batchnorm:
            batchnorm_op = self.batchnorm_op_factory(in_channels=out_channels)
            return torch.nn.Sequential(conv_op, batchnorm_op)
        else:
            return conv_op

    def activation_op_factory(self, index):
        # reuse the single shared activation instance
        return self.activation_op[0]

    def batchnorm_op_factory(self, in_channels):
        # pick BatchNorm2d / BatchNorm3d according to the spatial dim
        bn_cls_name = 'BatchNorm{}d'.format(int(self.dim))
        bn_op_cls = getattr(torch.nn, bn_cls_name)
        return bn_op_cls(in_channels)
class ResBlockUNet(UNetBase):
    """UNet whose convolutional blocks are residual blocks (`_ResBlock`).

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    dim : int
        Spatial dimensionality of the data (2 or 3).
    out_channels : int
        Number of output channels.
    unet_kwargs : dict or None
        Extra keyword arguments forwarded to `UNetBase.__init__`
        (e.g. depth, gain, residual).
    res_block_kwargs : dict or None
        Extra keyword arguments forwarded to each `_ResBlock`.
    activated : bool
        Whether the very last residual block (decoder, index 0) is
        activated; all other blocks are always activated.
    side_out_parts : str or tuple/list of str or None
        Which UNet parts ('down', 'bottom', 'up') contribute side outputs
        returned by the forward pass.
    """
    def __init__(self, in_channels, dim, out_channels, unet_kwargs=None,
                 res_block_kwargs=None, activated=True,
                 side_out_parts=None):
        self.dim = dim
        self.unet_kwargs = require_dict_kwargs(unet_kwargs,
                                               "unet_kwargs must be a dict or None")
        self.res_block_kwargs = require_dict_kwargs(res_block_kwargs,
                                                    "res_block_kwargs must be a dict or None")
        self.activated = activated
        # normalize side_out_parts to a set of part names
        if isinstance(side_out_parts, str):
            self.side_out_parts = set([side_out_parts])
        elif isinstance(side_out_parts, (tuple, list)):
            self.side_out_parts = set(side_out_parts)
        else:
            self.side_out_parts = set()

        super(ResBlockUNet, self).__init__(in_channels=in_channels,
                                           out_channels=out_channels,
                                           dim=dim,
                                           **self.unet_kwargs)

    def conv_op_factory(self, in_channels, out_channels, part, index):
        # is this the very last convolutional block?
        very_last = (part == 'up' and index == 0)

        # should the residual block be activated?
        # (always, except possibly for the very last block)
        activated = not very_last or self.activated

        # should the output be part of the overall
        # return-list in the forward pass of the UNet
        use_as_output = part in self.side_out_parts

        # residual block used within the UNet
        return _ResBlock(in_channels=in_channels, out_channels=out_channels,
                         dim=self.dim, activated=activated,
                         **self.res_block_kwargs), use_as_output
class UNetBase(nn.Module):
    """ Base class for implementing UNets.

    The depth and dimension of the UNet is flexible. The deriving classes must
    implement `conv_op_factory` and can implement `upsample_op_factory` and
    `downsample_op_factory`.

    Attributes:
        in_channels (int): Number of input channels.
        dim (int): Spatial dimension of data (must be 2 or 3).
        out_channels (int): Number of output channels. Set to None by default,
            which sets the number of out channels to the number of input
            channels to preserve symmetry of feature channels (default: None).
        depth (int): How many down-sampling / up-sampling steps shall be
            performed (default: 3).
        gain (int): Multiplicative increase of channels while going down in
            the UNet. The same factor is used to decrease the number of
            channels while going up in the UNet (default: 2).
        residual (bool): If residual is true, the output of the down-streams
            are added to the up-stream results. Otherwise the results are
            concatenated (default: False).
        p_dropout (float or None): If not None, channel dropout with this
            probability is applied after each down-stream conv block.
    """

    def __init__(self, in_channels, dim, out_channels=None, depth=3,
                 gain=2, residual=False, upsample_mode=None, p_dropout=None):
        super(UNetBase, self).__init__()

        # early sanity check
        if dim not in [2, 3]:
            raise RuntimeError("UNetBase is only implemented for 2D and 3D")

        # settings related members
        self.in_channels = int(in_channels)
        self.dim = int(dim)
        self.out_channels = self.in_channels if out_channels is \
            None else int(out_channels)
        self.depth = int(depth)
        self.gain = int(gain)
        self.residual = bool(residual)
        self.p_dropout = p_dropout

        # members to remember what to store as side output
        self._store_conv_down = []
        self._store_conv_bottom = False
        self._store_conv_up = []

        # number of channels per side output
        self.n_channels_per_output = []

        # members to hold actual nn.Modules / nn.ModuleLists
        self._pre_conv_down_ops = None
        self._post_conv_down_ops = None
        self._conv_down_ops = None
        self._pre_conv_up_ops = None
        self._post_conv_up_ops = None
        self._conv_up_ops = None
        self._upsample_ops = None
        self._downsample_ops = None
        self._pre_conv_bottom_ops = None
        self._post_conv_bottom_ops = None
        self._conv_bottom_op = None

        # upsample kwargs
        self._upsample_kwargs = self._make_upsample_kwargs(upsample_mode=upsample_mode)

        ########################################
        # default dropout
        ########################################
        # BUGFIX: this used to read `self.torch.nn.Dropout2d(...)`, which
        # raised AttributeError whenever p_dropout was set.
        if self.p_dropout is not None:
            self.use_dropout = True
            dropout_cls = nn.Dropout2d if self.dim == 2 else nn.Dropout3d
            self._channel_dropout_op = dropout_cls(p=float(self.p_dropout),
                                                   inplace=False)
        else:
            self.use_dropout = False

        # down-stream convolution blocks
        self._init__downstream()

        # pooling / downsample operators
        self._downsample_ops = nn.ModuleList([
            self.downsample_op_factory(i) for i in range(depth)
        ])

        # upsample operators
        # we flip the index that is given as argument to index consistently
        # in up and downstream sampling factories
        self._upsample_ops = nn.ModuleList([
            self.upsample_op_factory(depth - i - 1) for i in range(depth)
        ])

        # bottom block of the unet
        self._init__bottom()

        # up-stream convolution blocks
        self._init__upstream()

        assert len(self.n_channels_per_output) == self._store_conv_down.count(True) + \
            self._store_conv_up.count(True) + int(self._store_conv_bottom)

    def _get_num_channels(self, depth):
        # channel count at a given depth: in_channels * gain**depth
        assert depth > 0
        return self.in_channels * self.gain**depth

    def _init__downstream(self):
        """Build the encoder conv blocks via `conv_op_factory`."""
        conv_down_ops = []
        self._store_conv_down = []

        current_in_channels = self.in_channels
        for i in range(self.depth):
            out_channels = self._get_num_channels(i + 1)
            op, return_op_res = self.conv_op_factory(in_channels=current_in_channels,
                                                     out_channels=out_channels,
                                                     part='down', index=i)
            conv_down_ops.append(op)
            if return_op_res:
                self.n_channels_per_output.append(out_channels)
                self._store_conv_down.append(True)
            else:
                self._store_conv_down.append(False)

            # increase the number of channels
            current_in_channels = out_channels

        # store as proper torch ModuleList
        self._conv_down_ops = nn.ModuleList(conv_down_ops)
        return current_in_channels

    def _init__bottom(self):
        """Build the bottom conv block; factories may return just the op or
        an (op, store_as_side_output) tuple."""
        current_in_channels = self._get_num_channels(self.depth)

        factory_res = self.conv_op_factory(in_channels=current_in_channels,
                                           out_channels=current_in_channels,
                                           part='bottom', index=0)
        if isinstance(factory_res, tuple):
            self._conv_bottom_op, self._store_conv_bottom = factory_res
            if self._store_conv_bottom:
                self.n_channels_per_output.append(current_in_channels)
        else:
            self._conv_bottom_op = factory_res
            self._store_conv_bottom = False

    def _init__upstream(self):
        """Build the decoder conv blocks via `conv_op_factory`."""
        conv_up_ops = []
        current_in_channels = self._get_num_channels(self.depth)

        for i in range(self.depth):
            # the number of out channels (set to self.out_channels for the
            # last decoder)
            out_channels = self.out_channels if i + 1 == self.depth else \
                self._get_num_channels(self.depth - i - 1)

            # if not residual we concat which needs twice as many channels
            fac = 1 if self.residual else 2

            # we flip the index that is given as argument to index
            # consistently in up and downstream conv factories
            op, return_op_res = self.conv_op_factory(in_channels=fac * current_in_channels,
                                                     out_channels=out_channels,
                                                     part='up', index=self.depth - i - 1)
            conv_up_ops.append(op)
            if return_op_res:
                self.n_channels_per_output.append(out_channels)
                self._store_conv_up.append(True)
            else:
                self._store_conv_up.append(False)

            # decrease the number of input_channels
            current_in_channels = out_channels

        # store as proper torch ModuleList
        self._conv_up_ops = nn.ModuleList(conv_up_ops)

        # the last block needs to be stored in any case
        if not self._store_conv_up[-1]:
            self._store_conv_up[-1] = True
            self.n_channels_per_output.append(out_channels)

    def _make_upsample_kwargs(self, upsample_mode):
        """To avoid some warnings from pytorch, and some missing
        implementations, the arguments need to be handled carefully in this
        helper function.

        Args:
            upsample_mode (str): user's choice of upsampling interpolation
                style.
        """
        if upsample_mode is None:
            if self.dim == 2:
                upsample_mode = 'bilinear'
            elif self.dim == 3:
                upsample_mode = 'trilinear'

        upsample_kwargs = dict(scale_factor=2, mode=upsample_mode)
        if upsample_mode in ('bilinear', 'trilinear'):
            upsample_kwargs['align_corners'] = False
        return upsample_kwargs

    def _forward_sanity_check(self, input):
        """Raise RuntimeError on inputs that cannot flow through the UNet."""
        if isinstance(input, tuple):
            raise RuntimeError("tuples of tensors are not supported")
        shape = input.shape

        if shape[1] != self.in_channels:
            raise RuntimeError("wrong number of channels: expected %d, got %d"
                               % (self.in_channels, input.size(1)))

        if input.dim() != self.dim + 2:
            raise RuntimeError("wrong number of dim: expected %d, got %d"
                               % (self.dim + 2, input.dim()))
        self._check_scaling(input)

    # override if model has different scaling
    def _check_scaling(self, input):
        shape = input.shape
        mx = max_allowed_ds_steps(shape=shape[2:2 + self.dim], factor=2)
        if mx < self.depth:
            raise RuntimeError("cannot downsample %d times, with shape %s"
                               % (self.depth, str(input.size())))

    def forward(self, input):
        # check if input is suitable
        self._forward_sanity_check(input=input)

        # collect all desired outputs
        side_out = []

        # remember all conv-block results of the downward part of the UNet
        down_res = []

        #################################
        # downwards part
        #################################
        out = input
        for d in range(self.depth):
            out = self._conv_down_ops[d](out)
            # channel dropout; a no-op unless p_dropout was given.
            # (this replaces the previously dead `#out = self.dropout` line,
            # which left the configured dropout unused)
            out = self._dropout(out)
            down_res.append(out)

            if self._store_conv_down[d]:
                side_out.append(out)

            out = self._downsample_ops[d](out)

        #################################
        # bottom part
        #################################
        out = self._conv_bottom_op(out)
        if self._store_conv_bottom:
            side_out.append(out)

        #################################
        # upward part
        #################################
        down_res = list(reversed(down_res))  # <- eases indexing
        for d in range(self.depth):
            # upsample
            out = self._upsample_ops[d](out)

            # the result of the downward part
            a = down_res[d]

            # add or concat?
            if self.residual:
                out = a + out
            else:
                out = torch.cat([a, out], 1)

            # the convolutional block
            out = self._conv_up_ops[d](out)

            if self._store_conv_up[d]:
                side_out.append(out)

        # if len(side_out) == 1 we actually have no side output,
        # just the main output
        if len(side_out) == 1:
            return side_out[0]
        else:
            return tuple(side_out)

    def downsample_op_factory(self, index):
        C = nn.MaxPool2d if self.dim == 2 else nn.MaxPool3d
        return C(kernel_size=2, stride=2)

    def upsample_op_factory(self, index):
        # (removed a stray line-continuation backslash after the signature)
        return InfernoUpsample(**self._upsample_kwargs)
        # return nn.Upsample(**self._upsample_kwargs)

    def conv_op_factory(self, in_channels, out_channels, part, index):
        raise NotImplementedError("conv_op_factory need to be implemented by deriving class")

    def _dropout(self, x):
        # identity when dropout is disabled
        if self.use_dropout:
            return self._channel_dropout_op(x)
        else:
            return x
class UNet(UNetBase):
    """ Default 2d / 3d U-Net implementation following:
    https://arxiv.org/abs/1505.04597

    An initial convolution maps the raw input channels to
    `initial_features` feature maps; the UNetBase machinery then operates
    on those, and a final 1x1 convolution (optionally followed by an
    activation) produces `out_channels` outputs.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 dim,
                 depth=4,
                 initial_features=64,
                 gain=2,
                 final_activation=None,
                 p_dropout=None):

        # conv types for the inner blocks and for the final output conv
        self.default_conv = ConvELU2D if dim == 2 else ConvELU3D
        last_conv = Conv2D if dim == 2 else Conv3D

        # NOTE: self.default_conv must be set before the base constructor
        # runs, because the base constructor calls conv_op_factory.
        super(UNet, self).__init__(in_channels=initial_features,
                                   dim=dim,
                                   depth=depth,
                                   gain=gain,
                                   p_dropout=p_dropout)

        # initial conv layer: maps the data's channel count (usually 1 or 3)
        # to the initial number of feature maps
        self._initial_conv = self.default_conv(in_channels, initial_features, 3)

        # resolve the final activation
        if final_activation is None:
            activation = None
        elif isinstance(final_activation, str):
            activation = getattr(nn, final_activation)()
        elif isinstance(final_activation, nn.Module):
            activation = final_activation
        else:
            raise NotImplementedError("Activation of type %s is not supported" % type(final_activation))

        # override the unet base attribute for out_channels
        self.out_channels = int(out_channels)

        # final 1x1 conv, optionally wrapped with the activation
        output_conv = last_conv(initial_features, self.out_channels, 1)
        if activation is None:
            self._output = output_conv
        else:
            self._output = nn.Sequential(output_conv, activation)

    def forward(self, input):
        # TODO implement 2d from 3d input (see neurofire)
        features = self._initial_conv(input)
        features = super(UNet, self).forward(features)
        return self._output(features)

    def conv_op_factory(self, in_channels, out_channels, part, index):
        # the very first block needs only a single convolution, because
        # `_initial_conv` has already been applied
        is_first = (part == 'down' and index == 0)
        if is_first:
            return self.default_conv(in_channels, out_channels, 3), False
        block = nn.Sequential(self.default_conv(in_channels, out_channels, 3),
                              self.default_conv(out_channels, out_channels, 3))
        return block, False
class Adam(Optimizer):
    """Implements Adam algorithm with the option of adding a L1 penalty.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        lambda_l1 (float, optional): L1 penalty (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 lambda_l1=0, weight_decay=0, **kwargs):
        # **kwargs lets subclasses (e.g. AnnealedAdam) stash extra
        # per-group settings such as lr_decay
        defaults = dict(lr=lr, betas=betas, eps=eps, lambda_l1=lambda_l1,
                        weight_decay=weight_decay, **kwargs)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving averages of gradient and of squared
                    # gradient values. (`new_zeros` replaces the legacy
                    # `new().resize_as_().zero_()` idiom, which was removed
                    # from modern PyTorch.)
                    state['exp_avg'] = grad.new_zeros(grad.size())
                    state['exp_avg_sq'] = grad.new_zeros(grad.size())

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # L1 penalty: add lambda_l1 * sign(w) to the gradient.
                # (keyword `alpha=` forms replace the deprecated/removed
                # positional-scalar overloads of add_/addcmul_/addcdiv_)
                if group['lambda_l1'] != 0:
                    grad.add_(p.data.sign(), alpha=group['lambda_l1'])
                # L2 penalty (weight decay)
                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
class AnnealedAdam(Adam):
    """Implements Adam algorithm with learning rate annealing and optional L1 penalty.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        lambda_l1 (float, optional): L1 penalty (default: 0)
        weight_decay (float, optional): L2 penalty (weight decay) (default: 0)
        lr_decay(float, optional): decay learning rate by this factor after
            every step (default: 1.)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 lambda_l1=0, weight_decay=0, lr_decay=1.):
        # lr_decay rides along in the per-group defaults via Adam's **kwargs
        defaults = dict(lr=lr, betas=betas, eps=eps, lambda_l1=lambda_l1,
                        weight_decay=weight_decay, lr_decay=lr_decay)
        super(AnnealedAdam, self).__init__(params, **defaults)

    def step(self, closure=None):
        """Performs a single optimization step, then anneals the learning rate.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss computed by `closure`, if given. (BUGFIX: the original
            discarded Adam.step's return value, breaking the standard
            `loss = optimizer.step(closure)` idiom.)
        """
        # Do an optimization step
        loss = super(AnnealedAdam, self).step(closure=closure)
        # Update learning rate
        for group in self.param_groups:
            group['lr'] *= group['lr_decay']
        return loss
class BinaryBlobs(data.Dataset):
    """Synthetic segmentation dataset: binary blob label images with a noisy
    grayscale rendering as input.

    Each item is an (image, label) pair. The blob layout is seeded
    deterministically per (split, index); the added noise uses the global
    numpy RNG. NOTE(review): the noise is therefore NOT reproducible across
    epochs/workers — confirm whether that is intended.
    """

    def __init__(self, size=20, length=512, blob_size_fraction=0.1,
                 n_dim=2, volume_fraction=0.5, split='train',
                 uniform_noise_range=(-1.2, 1.2),
                 gaussian_noise_sigma=1.2,
                 noise_scale_factor=8,
                 image_transform=None,
                 label_transform=None,
                 joint_transform=None):
        # how many images are in the dataset
        self.size = size

        # blob related members (forwarded to skimage.data.binary_blobs)
        self.length = length
        self.blob_size_fraction = blob_size_fraction
        self.n_dim = n_dim
        self.volume_fraction = volume_fraction

        # which split {'train', 'test', 'validate'}
        self.split = split

        # noise related members
        self.uniform_noise_range = uniform_noise_range
        self.gaussian_noise_sigma = float(gaussian_noise_sigma)
        self.noise_scale_factor = noise_scale_factor

        # transforms
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.joint_transform = joint_transform

        # internal: disjoint seed ranges per split so splits never overlap
        split_to_seed = dict(train=0, test=1, validate=2)
        self.master_seed = split_to_seed[self.split] * self.size

    def __getitem__(self, index):
        # Explicit import: older scikit-image versions do not auto-import
        # submodules, so `import skimage.data` alone would leave
        # `skimage.transform` unresolved below.
        import skimage.transform

        # generate the labels
        # NOTE(review): newer scikit-image renames `seed` to `rng` — confirm
        # the installed version still accepts `seed`.
        label = skimage.data.binary_blobs(
            length=self.length,
            blob_size_fraction=self.blob_size_fraction,
            n_dim=self.n_dim,
            volume_fraction=self.volume_fraction,
            seed=self.master_seed + index)

        # make the raw image [-1, 1]
        image = label.astype('float32') * 2
        image -= 1

        # add uniform noise
        low, high = self.uniform_noise_range
        uniform_noise = numpy.random.uniform(low=low, high=high,
                                             size=image.size)
        image += uniform_noise.reshape(image.shape)

        # add gaussian noise
        gaussian_noise = numpy.random.normal(scale=self.gaussian_noise_sigma,
                                             size=image.size)
        image += gaussian_noise.reshape(image.shape)

        # generate noise at lower scales and blow it up to image size
        small_shape = [s // self.noise_scale_factor for s in label.shape]
        small_size = reduce(mul, small_shape, 1)

        small_noise_img = numpy.random.uniform(low=low, high=high,
                                               size=small_size)
        small_noise_img = small_noise_img.reshape(small_shape)

        gaussian_noise = numpy.random.normal(scale=self.gaussian_noise_sigma,
                                             size=small_size)
        small_noise_img += gaussian_noise.reshape(small_shape)

        noise_img = skimage.transform.resize(image=small_noise_img,
                                             output_shape=image.shape,
                                             mode='reflect')
        image += noise_img

        # normalize to zero mean / unit variance
        image -= image.mean()
        image /= image.std()

        # BUGFIX: 'long' is platform-dependent (32-bit on Windows) and the
        # alias was removed in NumPy 2.0; torch label tensors need int64.
        label = label.astype('int64')

        try:
            # Apply transforms
            if self.image_transform is not None:
                image = self.image_transform(image)
            if self.label_transform is not None:
                label = self.label_transform(label)
            if self.joint_transform is not None:
                image, label = self.joint_transform(image, label)
        except Exception:
            print("[!] An Exception occurred while applying the transforms at "
                  "index {} of split '{}'.".format(index, self.split))
            raise

        # add the channel axis
        image = image[None, ...]
        return image, label

    def __len__(self):
        return self.size
def get_binary_blob_loaders(train_batch_size=1, test_batch_size=1, num_workers=1,
                            train_image_transform=None,
                            train_label_transform=None,
                            train_joint_transform=None,
                            validate_image_transform=None,
                            validate_label_transform=None,
                            validate_joint_transform=None,
                            test_image_transform=None,
                            test_label_transform=None,
                            test_joint_transform=None,
                            **kwargs):
    """Build train / test / validate DataLoaders over BinaryBlobs datasets.

    Extra **kwargs are forwarded to every BinaryBlobs constructor.
    Returns (trainloader, testloader, validloader).
    """
    # per-split (image, label, joint) transform triples
    transforms_for_split = {
        'train': (train_image_transform, train_label_transform, train_joint_transform),
        'test': (test_image_transform, test_label_transform, test_joint_transform),
        'validate': (validate_image_transform, validate_label_transform, validate_joint_transform),
    }

    datasets = {
        split: BinaryBlobs(split=split,
                           image_transform=image_t,
                           label_transform=label_t,
                           joint_transform=joint_t,
                           **kwargs)
        for split, (image_t, label_t, joint_t) in transforms_for_split.items()
    }

    trainloader = data.DataLoader(datasets['train'], batch_size=train_batch_size,
                                  num_workers=num_workers)
    testloader = data.DataLoader(datasets['test'], batch_size=test_batch_size,
                                 num_workers=num_workers)
    # note: the validation loader reuses test_batch_size, as in the
    # original API
    validloader = data.DataLoader(datasets['validate'], batch_size=test_batch_size,
                                  num_workers=num_workers)

    return trainloader, testloader, validloader
def make_dataset(dir):
    """Recursively collect the paths of all image files below *dir*.

    Directories are visited in sorted `os.walk` order; within a directory,
    filenames keep the order `os.walk` reports them in. Image-ness is decided
    by `is_image_file`.
    """
    return [
        os.path.join(root, fname)
        for root, _, fnames in sorted(os.walk(dir))
        for fname in fnames
        if is_image_file(fname)
    ]
class CamVid(data.Dataset):
    """CamVid street-scene segmentation dataset.

    Expects the directory layout of
    https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid :
    images under `<root>/<split>` and labels under `<root>/<split>annot`.
    """

    # accepted split aliases -> canonical directory names
    SPLIT_NAME_MAPPING = {'train': 'train',
                          'training': 'train',
                          'validate': 'val',
                          'val': 'val',
                          'validation': 'val',
                          'test': 'test',
                          'testing': 'test'}
    # Dataset statistics
    CLASS_WEIGHTS = CAMVID_CLASS_WEIGHTS
    CLASSES = CAMVID_CLASSES
    MEAN = CAMVID_MEAN
    STD = CAMVID_STD

    def __init__(self, root, split='train',
                 image_transform=None, label_transform=None,
                 joint_transform=None,
                 download=False,
                 loader=default_loader):
        # Validate the requested split
        assert_(split in self.SPLIT_NAME_MAPPING.keys(),
                "`split` must be one of {}".format(set(self.SPLIT_NAME_MAPPING.keys())),
                KeyError)
        # Root directory and canonical split name
        self.root_directory = root
        self.split = self.SPLIT_NAME_MAPPING.get(split)
        # Utils
        self.image_loader = loader
        # Transforms
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.joint_transform = joint_transform
        # For when we implement download:
        if download:
            self.download()
        # Collect the image paths for this split
        self.image_paths = make_dataset(os.path.join(self.root_directory, self.split))

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = self.image_loader(image_path)
        # labels live in the sibling '<split>annot' directory; note this
        # substitutes every occurrence of the split name in the path
        label = Image.open(image_path.replace(self.split, self.split + 'annot'))
        # Apply transforms
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)
        if self.joint_transform is not None:
            image, label = self.joint_transform(image, label)
        return image, label

    def __len__(self):
        return len(self.image_paths)

    def download(self):
        # TODO: please download the dataset from
        # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
        raise NotImplementedError
# noinspection PyTypeChecker
def get_camvid_loaders(root_directory, image_shape=(360, 480), labels_as_onehot=False,
                       train_batch_size=1, validate_batch_size=1, test_batch_size=1,
                       num_workers=2):
    """Build train / validate / test DataLoaders for CamVid.

    The transform pipeline order below is deliberate (crop -> rescale ->
    flip -> cast -> batchify); do not reorder.

    Returns (train_loader, validate_loader, test_loader).
    """
    # Make transforms
    # image-only pipeline: to array, to [0, 1], random gamma, standardize
    image_transforms = Compose(PILImage2NumPyArray(),
                               NormalizeRange(),
                               RandomGammaCorrection(),
                               Normalize(mean=CAMVID_MEAN, std=CAMVID_STD))
    label_transforms = PILImage2NumPyArray()
    # joint pipeline applied to (image, label) pairs; `apply_to` indices
    # select image (0) or label (1)
    joint_transforms = Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),
                                               preserve_aspect_ratio=True),
                               # Scale raw image back to the original shape
                               Scale(output_image_shape=image_shape,
                                     interpolation_order=3, apply_to=[0]),
                               # Scale segmentation back to the original shape
                               # (without interpolation)
                               Scale(output_image_shape=image_shape,
                                     interpolation_order=0, apply_to=[1]),
                               RandomFlip(allow_ud_flips=False),
                               # Cast raw image to float
                               Cast('float', apply_to=[0]))
    if labels_as_onehot:
        # See cityscapes loader to understand why this is here.
        joint_transforms\
            .add(Label2OneHot(num_classes=len(CAMVID_CLASS_WEIGHTS), dtype='bool',
                              apply_to=[1]))\
            .add(Cast('float', apply_to=[1]))
    else:
        # Cast label image to long
        joint_transforms.add(Cast('long', apply_to=[1]))
    # Batchify
    joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))
    # Build datasets (all three splits share the same transform objects)
    train_dataset = CamVid(root_directory, split='train',
                           image_transform=image_transforms,
                           label_transform=label_transforms,
                           joint_transform=joint_transforms)
    validate_dataset = CamVid(root_directory, split='validate',
                              image_transform=image_transforms,
                              label_transform=label_transforms,
                              joint_transform=joint_transforms)
    test_dataset = CamVid(root_directory, split='test',
                          image_transform=image_transforms,
                          label_transform=label_transforms,
                          joint_transform=joint_transforms)
    # Build loaders
    train_loader = data.DataLoader(train_dataset, batch_size=train_batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   pin_memory=True)
    validate_loader = data.DataLoader(validate_dataset, batch_size=validate_batch_size,
                                      shuffle=True, num_workers=num_workers,
                                      pin_memory=True)
    test_loader = data.DataLoader(test_dataset, batch_size=test_batch_size,
                                  shuffle=True, num_workers=num_workers,
                                  pin_memory=True)
    return train_loader, validate_loader, test_loader
def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,
                        download=False, augment=False, validation_dataset_size=None):
    """Build DataLoaders for CIFAR-10.

    When `validation_dataset_size` is given, the training set is split into
    train/validation subsets (validation uses the test-time transform) and
    (trainloader, validloader, testloader) is returned; otherwise
    (trainloader, testloader).
    """
    # Data preparation for CIFAR10.
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.247, 0.2435, 0.2616)
    normalize = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    if augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])

    data_root = os.path.join(root_directory, 'data')
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                            download=download,
                                            transform=transform_train)
    if validation_dataset_size:
        # carve a validation subset off the end of a random permutation
        indices = torch.randperm(len(trainset))
        split_at = len(indices) - validation_dataset_size
        train_indices = indices[:split_at]
        valid_indices = indices[split_at:]
        # the validation view of the training data uses test-time transforms
        validset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                                download=download,
                                                transform=transform_test)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
                                                  pin_memory=True, num_workers=1,
                                                  sampler=SubsetRandomSampler(train_indices))
        validloader = torch.utils.data.DataLoader(validset, batch_size=test_batch_size,
                                                  pin_memory=True, num_workers=1,
                                                  sampler=SubsetRandomSampler(valid_indices))
    else:
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
                                                  shuffle=True, pin_memory=True,
                                                  num_workers=1)

    testset = torchvision.datasets.CIFAR10(root=data_root, train=False,
                                           download=download,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                             shuffle=False, pin_memory=True,
                                             num_workers=1)

    if validation_dataset_size:
        return trainloader, validloader, testloader
    return trainloader, testloader
testset = torchvision.datasets.CIFAR10(root=os.path.join(root_directory, 'data'), train=False, download=download, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, pin_memory=True, num_workers=1) if validation_dataset_size: return trainloader, validloader, testloader else: return trainloader, testloader def get_cifar100_loaders(root_directory, train_batch_size=128, test_batch_size=100, download=False, augment=False, validation_dataset_size=None): # Data preparation for CIFAR100. Adapted from # https://github.com/kuangliu/pytorch-cifar/blob/master/main.py if augment: transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)), ]) else: transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)), ]) trainset = torchvision.datasets.CIFAR100(root=os.path.join(root_directory, 'data'), train=True, download=download, transform=transform_train) if validation_dataset_size: indices = torch.randperm(len(trainset)) train_indices = indices[:(len(indices) - validation_dataset_size)] valid_indices = indices[(len(indices) - validation_dataset_size):] validset = torchvision.datasets.CIFAR100(root=os.path.join(root_directory, 'data'), train=True, download=download, transform=transform_test) trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, pin_memory=True, num_workers=1, sampler=SubsetRandomSampler(train_indices)) validloader = 
# Raw class-id -> human-readable name, as defined by the official Cityscapes scripts.
CITYSCAPES_CLASSES = {
    0: 'unlabeled', 1: 'ego vehicle', 2: 'rectification border', 3: 'out of roi',
    4: 'static', 5: 'dynamic', 6: 'ground', 7: 'road', 8: 'sidewalk', 9: 'parking',
    10: 'rail track', 11: 'building', 12: 'wall', 13: 'fence', 14: 'guard rail',
    15: 'bridge', 16: 'tunnel', 17: 'pole', 18: 'polegroup', 19: 'traffic light',
    20: 'traffic sign', 21: 'vegetation', 22: 'terrain', 23: 'sky', 24: 'person',
    25: 'rider', 26: 'car', 27: 'truck', 28: 'bus', 29: 'caravan', 30: 'trailer',
    31: 'train', 32: 'motorcycle', 33: 'bicycle', -1: 'license plate'
}

# Training label reserved for all classes that are ignored during training/eval.
IGNORE_CLASS_LABEL = 19

# Class labels to use for training, found here:
# https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py#L61
CITYSCAPES_CLASSES_TO_LABELS = {
    0: IGNORE_CLASS_LABEL, 1: IGNORE_CLASS_LABEL, 2: IGNORE_CLASS_LABEL,
    3: IGNORE_CLASS_LABEL, 4: IGNORE_CLASS_LABEL, 5: IGNORE_CLASS_LABEL,
    6: IGNORE_CLASS_LABEL, 7: 0, 8: 1, 9: IGNORE_CLASS_LABEL,
    10: IGNORE_CLASS_LABEL, 11: 2, 12: 3, 13: 4, 14: IGNORE_CLASS_LABEL,
    15: IGNORE_CLASS_LABEL, 16: IGNORE_CLASS_LABEL, 17: 5,
    18: IGNORE_CLASS_LABEL, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
    25: 12, 26: 13, 27: 14, 28: 15, 29: IGNORE_CLASS_LABEL,
    30: IGNORE_CLASS_LABEL, 31: 16, 32: 17, 33: 18, -1: IGNORE_CLASS_LABEL
}

# Map classes to official cityscapes colors
CITYSCAPES_CLASS_COLOR_MAPPING = {
    0: (0, 0, 0), 1: (0, 0, 0), 2: (0, 0, 0), 3: (0, 0, 0), 4: (0, 0, 0),
    5: (111, 74, 0), 6: (81, 0, 81), 7: (128, 64, 128), 8: (244, 35, 232),
    9: (250, 170, 160), 10: (230, 150, 140), 11: (70, 70, 70),
    12: (102, 102, 156), 13: (190, 153, 153), 14: (180, 165, 180),
    15: (150, 100, 100), 16: (150, 120, 90), 17: (153, 153, 153),
    18: (153, 153, 153), 19: (250, 170, 30), 20: (220, 220, 0),
    21: (107, 142, 35), 22: (152, 251, 152), 23: (70, 130, 180),
    24: (220, 20, 60), 25: (255, 0, 0), 26: (0, 0, 142), 27: (0, 0, 70),
    28: (0, 60, 100), 29: (0, 0, 90), 30: (0, 0, 110), 31: (0, 80, 100),
    32: (0, 0, 230), 33: (119, 11, 32), -1: (0, 0, 142),
}

# Per-training-label loss weights: every trainable label gets weight 1.,
# the ignore label (19) gets weight 0.
CITYSCAPES_LABEL_WEIGHTS = {label: (0. if label == IGNORE_CLASS_LABEL else 1.)
                            for label in range(20)}

# 0:void 1:flat 2:construction 3:object 4:nature 5:sky 6:human 7:vehicle
CITYSCAPES_CATEGORIES = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
                         3, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7]

# Per raw class: whether the official evaluation ignores it.
CITYSCAPES_IGNORE_IN_EVAL = [True, True, True, True, True, True, True, False,
                             False, True, True, False, False, False, True, True,
                             True, False, True, False, False, False, False,
                             False, False, False, False, False, False, True,
                             True, False, False, False, True]

# Per-channel dataset statistics (RGB, range [0, 1]).
CITYSCAPES_MEAN = [0.28689554, 0.32513303, 0.28389177]
CITYSCAPES_STD = [0.18696375, 0.19017339, 0.18720214]


def get_matching_labelimage_file(f, groundtruth):
    """Derive the label-image path matching raw-image path `f`.

    Swaps the top-level folder for `groundtruth` (e.g. 'gtFine') and replaces
    the 'leftImg8bit' filename suffix with '<groundtruth>_labelIds'.
    Paths are assumed to use '/' separators (as in zip archives).
    """
    fs = f.split('/')
    fs[0] = groundtruth
    fs[-1] = str.replace(fs[-1], 'leftImg8bit', groundtruth + '_labelIds')
    return '/'.join(fs)


def get_filelist(path):
    """List the members of a zip archive, or all files under a directory.

    Zip members are returned as `ZipInfo` objects, directory entries as paths
    relative to the directory's parent.

    Raises
    ------
    NotImplementedError
        If `path` is neither a zip archive nor a directory.
    """
    if path.endswith('.zip'):
        # BUGFIX: close the archive handle after listing (the original leaked
        # an open ZipFile). The ZipInfo objects stay valid after closing.
        with zipfile.ZipFile(path, 'r') as archive:
            return archive.filelist
    elif os.path.isdir(path):
        return [relpath(join(root, filename), abspath(join(path, '..')))
                for root, _, filenames in os.walk(path)
                for filename in filenames]
    else:
        raise NotImplementedError("Path must be a zip archive or a directory.")


def make_dataset(path, split):
    """Collect (raw image, label image) path pairs for the given `split`."""
    images = []
    for f in get_filelist(path):
        if isinstance(f, str):
            fn = f
            fns = f.split('/')
        else:
            # Zip members are ZipInfo objects
            fn = f.filename
            fns = f.filename.split('/')
        if fns[-1].endswith('.png') and fns[1] == split:
            # use first folder name to identify train/val/test images
            if split == 'train_extra':
                groundtruth = 'gtCoarse'
            else:
                groundtruth = 'gtFine'
            fl = get_matching_labelimage_file(fn, groundtruth)
            images.append((f, fl))
    return images


def extract_image(path, image_path):
    """Open the image at `image_path` located inside archive/directory `path`."""
    if path.endswith('.zip'):
        # BUGFIX: read the compressed bytes inside a context manager so the
        # archive handle is closed (the original left it open per call);
        # decoding then happens on the in-memory buffer.
        with zipfile.ZipFile(path, 'r') as archive:
            data = archive.read(image_path)
        return Image.open(io.BytesIO(data))
    else:
        return Image.open(join(abspath(join(path, '..')), image_path), 'r')
class Cityscapes(data.Dataset):
    """Cityscapes semantic-segmentation dataset, read from zip archives or folders."""

    # Accepted aliases for each canonical split name
    SPLIT_NAME_MAPPING = {'train': 'train', 'training': 'train',
                          'validate': 'val', 'val': 'val', 'validation': 'val',
                          'test': 'test', 'testing': 'test',
                          'training_extra': 'train_extra', 'train_extra': 'train_extra'}
    # Dataset statistics
    CLASSES = CITYSCAPES_CLASSES
    MEAN = CITYSCAPES_MEAN
    STD = CITYSCAPES_STD
    # Known-bad images to skip at fetch time
    BLACKLIST = ['leftImg8bit/train_extra/troisdorf/troisdorf_000000_000073_leftImg8bit.png']

    def __init__(self, root_folder, split='train', read_from_zip_archive=True,
                 image_transform=None, label_transform=None, joint_transform=None):
        """
        Parameters:
            root_folder: folder that contains both leftImg8bit_trainvaltest.zip and
                         gtFine_trainvaltest.zip archives.
            split: name of dataset spilt (i.e. 'train_extra', 'train', 'val' or 'test')
        """
        assert_(split in self.SPLIT_NAME_MAPPING.keys(),
                "`split` must be one of {}".format(set(self.SPLIT_NAME_MAPPING.keys())),
                KeyError)
        self.split = self.SPLIT_NAME_MAPPING.get(split)
        self.read_from_zip_archive = read_from_zip_archive
        # Get roots
        self.image_root, self.label_root = [join(root_folder, groot)
                                            for groot in self.get_image_and_label_roots()]
        # Transforms
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.joint_transform = joint_transform
        # Make list with paths to the images
        self.image_paths = make_dataset(self.image_root, self.split)

    def __getitem__(self, index):
        pi, pl = self.image_paths[index]
        if pi in self.BLACKLIST:
            # Select the next image if the current image is bad. BUGFIX: wrap
            # around with modulo so a blacklisted sample at the *last* index
            # does not raise an IndexError (previously `self[index + 1]`).
            return self[(index + 1) % len(self)]
        image = extract_image(self.image_root, pi)
        label = extract_image(self.label_root, pl)
        try:
            # Apply transforms
            if self.image_transform is not None:
                image = self.image_transform(image)
            if self.label_transform is not None:
                label = self.label_transform(label)
            if self.joint_transform is not None:
                image, label = self.joint_transform(image, label)
        except Exception:
            print("[!] An Exception occurred while applying the transforms at "
                  "index {} of split '{}'.".format(index, self.split))
            raise
        return image, label

    def __len__(self):
        return len(self.image_paths)

    def download(self):
        # TODO: please download the dataset from
        # https://www.cityscapes-dataset.com/
        raise NotImplementedError

    def get_image_and_label_roots(self):
        """Return (image root, label root) for the configured split and storage mode."""
        all_roots = {
            'zipped': {
                'train': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),
                'val': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),
                'train_extra': ('leftImg8bit_trainextra.zip', 'gtCoarse.zip')
            },
            'unzipped': {
                'train': ('leftImg8bit', 'gtFine'),
                'val': ('leftImg8bit', 'gtFine'),
                'train_extra': ('leftImg8bit', 'gtCoarse')
            }
        }
        image_and_label_roots = all_roots\
            .get('zipped' if self.read_from_zip_archive else 'unzipped').get(self.split)
        return image_and_label_roots
def make_transforms(image_shape, labels_as_onehot):
    """Assemble the Cityscapes transform pipelines.

    Returns a dict with keys 'image_transform', 'label_transform' and
    'joint_transform', ready to splat into the `Cityscapes` constructor.
    """
    # Pipeline applied to the raw image only
    image_transform = Compose(PILImage2NumPyArray(),
                              NormalizeRange(),
                              RandomGammaCorrection(),
                              Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD))
    # Pipeline applied to the label image only: raw class ids -> training labels
    label_transform = Compose(PILImage2NumPyArray(),
                              Project(projection=CITYSCAPES_CLASSES_TO_LABELS))
    # Joint pipeline applied to the (image, label) pair
    joint_transform = Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),
                                              preserve_aspect_ratio=True),
                              # Scale raw image back to the original shape
                              Scale(output_image_shape=image_shape,
                                    interpolation_order=3, apply_to=[0]),
                              # Scale segmentation back to the original shape
                              # (without interpolation)
                              Scale(output_image_shape=image_shape,
                                    interpolation_order=0, apply_to=[1]),
                              RandomFlip(allow_ud_flips=False),
                              # Cast raw image to float
                              Cast('float', apply_to=[0]))
    if labels_as_onehot:
        # One-hot encode *after* crop + scale: applying Label2OneHot on the full
        # label image would be needlessly expensive, since most of it is thrown
        # away again (tests showed ~1 sec per image saved this way).
        joint_transform.add(Label2OneHot(num_classes=len(CITYSCAPES_LABEL_WEIGHTS),
                                         dtype='bool', apply_to=[1]))
        joint_transform.add(Cast('float', apply_to=[1]))
    else:
        # Cast label image to long
        joint_transform.add(Cast('long', apply_to=[1]))
    # Batchify
    joint_transform.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))
    return {'image_transform': image_transform,
            'label_transform': label_transform,
            'joint_transform': joint_transform}


def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048),
                           labels_as_onehot=False, include_coarse_dataset=False,
                           read_from_zip_archive=True, train_batch_size=1,
                           validate_batch_size=1, num_workers=2):
    """Build train and validation `DataLoader`s for Cityscapes.

    Each dataset gets its own freshly-built transform pipelines (transforms
    carry random state, so they are not shared between datasets).
    """
    training_set = Cityscapes(root_directory, split='train',
                              read_from_zip_archive=read_from_zip_archive,
                              **make_transforms(image_shape, labels_as_onehot))
    if include_coarse_dataset:
        # Prepend the coarsely-annotated extra training data
        coarse_set = Cityscapes(root_directory, split='train_extra',
                                read_from_zip_archive=read_from_zip_archive,
                                **make_transforms(image_shape, labels_as_onehot))
        training_set = Concatenate(coarse_set, training_set)
    validation_set = Cityscapes(root_directory, split='validate',
                                read_from_zip_archive=read_from_zip_archive,
                                **make_transforms(image_shape, labels_as_onehot))
    # Build loaders
    train_loader = data.DataLoader(training_set, batch_size=train_batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   pin_memory=True)
    validate_loader = data.DataLoader(validation_set, batch_size=validate_batch_size,
                                      shuffle=True, num_workers=num_workers,
                                      pin_memory=True)
    return train_loader, validate_loader
class SyncableDataset(Dataset):
    """A `Dataset` whose length is defined by an optional `base_sequence`.

    The `base_sequence` is also the payload exchanged by `sync_with`, which
    lets several datasets agree on a common indexing sequence.
    """

    def __init__(self, base_sequence=None):
        # Subclasses may also assign self.base_sequence later instead of
        # passing it to the constructor.
        self.base_sequence = base_sequence

    def sync_with(self, dataset):
        """Adopt `dataset`'s `base_sequence` (if it has one); returns self for chaining."""
        if hasattr(dataset, 'base_sequence'):
            self.base_sequence = dataset.base_sequence
        return self

    def __len__(self):
        # Guard clause: a set base_sequence defines the length directly.
        if self.base_sequence is not None:
            return len(self.base_sequence)
        raise RuntimeError("Class {} does not specify a base sequence. Either specify "
                           "one by assigning to self.base_sequence or override the "
                           "__len__ method.".format(self.__class__.__name__))


class IndexSpec(object):
    """
    Class to wrap any extra index information a `Dataset` object might want to send back.
    This could be useful in (say) inference, where we would wish to (asynchronously) know
    more about the current input.
    """

    def __init__(self, index=None, base_sequence_at_index=None):
        self.index = index
        self.base_sequence_at_index = base_sequence_at_index

    def __int__(self):
        # Allows an IndexSpec to be used wherever a plain integer index is expected.
        return int(self.index)
class Concatenate(Dataset):
    """
    Concatenates multiple datasets to one.
    This class does not implement synchronization primitives.
    """

    def __init__(self, *datasets, transforms=None):
        assert all([isinstance(dataset, Dataset) for dataset in datasets])
        assert len(datasets) >= 1
        assert transforms is None or callable(transforms)
        self.datasets = datasets
        self.transforms = transforms

    def map_index(self, index):
        """Map a global `index` to a (dataset_index, index_in_dataset) pair.

        Example: with dataset lengths [4, 3, 3], index 5 maps to (1, 1).
        """
        # Cumulate the lengths, e.g. [4, 3, 3] -> [4, 7, 10]; the target dataset
        # is the first one whose cumulative length exceeds the index.
        cumulative_len_list = np.cumsum(list(map(len, self.datasets)))
        dataset_index = int(np.argmax(cumulative_len_list - index > 0))
        if dataset_index == 0:
            # First dataset - index corresponds to index_in_dataset
            index_in_dataset = index
        else:
            # What's left after the cumulated length of the preceding datasets
            index_in_dataset = int(index - cumulative_len_list[dataset_index - 1])
        return dataset_index, index_in_dataset

    def __getitem__(self, index):
        # BUGFIX: explicit bounds check instead of `assert` — an assert is
        # stripped under `python -O`, and raising IndexError (not AssertionError)
        # also rejects negative indices and terminates old-style iteration.
        if not 0 <= index < len(self):
            raise IndexError("Index {} out of range for Concatenate of length {}."
                             .format(index, len(self)))
        dataset_index, index_in_dataset = self.map_index(index)
        fetched = self.datasets[dataset_index][index_in_dataset]
        if self.transforms is None:
            return fetched
        elif callable(self.transforms):
            return self.transforms(*pyu.to_iterable(fetched))
        else:
            raise NotImplementedError

    def __len__(self):
        return sum([len(dataset) for dataset in self.datasets])

    def __repr__(self):
        if len(self.datasets) < 3:
            # BUGFIX: join over *all* datasets (the original emitted a stray
            # leading ", " for a single dataset).
            return "Concatenate({})".format(
                ", ".join([dataset.__repr__() for dataset in self.datasets]))
        else:
            return "Concatenate({}xDatasets)".format(len(self.datasets))


# ================== FILE: inferno/io/core/data_utils.py ==================

def implements_sync_primitives(dataset):
    """Return True if `dataset` exposes a callable `sync_with` method."""
    return hasattr(dataset, 'sync_with') and callable(getattr(dataset, 'sync_with'))


def defines_base_sequence(dataset):
    """Return True if `dataset` has a non-None `base_sequence` attribute."""
    return hasattr(dataset, 'base_sequence') and dataset.base_sequence is not None
class Zip(SyncableDataset):
    """
    Zip two or more datasets to one dataset. If the datasets implement
    synchronization primitives, they are all synchronized with the first dataset.
    """

    def __init__(self, *datasets, sync=False, transforms=None):
        super(Zip, self).__init__()
        assert_(len(datasets) >= 1,
                "Expecting one or more datasets, got none.",
                ValueError)
        for position, candidate in enumerate(datasets):
            assert_(isinstance(candidate, Dataset),
                    "Object at position {} of type {} is not a subclass of "
                    "`torch.utils.data.dataset.Dataset`"
                    .format(position, type(candidate).__name__),
                    TypeError)
        assert_(transforms is None or callable(transforms),
                "Given `transforms` is not callable.",
                TypeError)
        self.datasets = datasets
        self.sync = sync
        self.transforms = transforms
        if self.sync:
            self.sync_datasets()
        # When sync'ing and every zipped dataset defines a base sequence,
        # inherit the element-wise zip of those sequences.
        if self.sync and all(du.defines_base_sequence(dataset)
                             for dataset in self.datasets):
            self.base_sequence = list(zip(*[dataset.base_sequence
                                            for dataset in self.datasets]))
        else:
            self.base_sequence = None

    def sync_datasets(self):
        """Synchronize every dataset (that supports it) with the first one."""
        leader = self.datasets[0]
        for follower in self.datasets[1:]:
            if du.implements_sync_primitives(follower):
                follower.sync_with(leader)

    def sync_with(self, dataset):
        """Synchronize the first dataset with `dataset`, then re-sync the rest."""
        leader = self.datasets[0]
        if du.implements_sync_primitives(leader):
            leader.sync_with(dataset)
        # Propagate to all other datasets
        self.sync_datasets()

    def __getitem__(self, index):
        assert_(index < len(self), exception_type=IndexError)
        fetched = [dataset[index] for dataset in self.datasets]
        if self.transforms is None:
            return fetched
        if callable(self.transforms):
            return self.transforms(*fetched)
        raise RuntimeError

    def __len__(self):
        # A base sequence (set up while sync'ing) takes precedence; otherwise
        # the shortest zipped dataset bounds the length.
        if du.defines_base_sequence(self):
            return super(Zip, self).__len__()
        return min(len(dataset) for dataset in self.datasets)

    def __repr__(self):
        if len(self.datasets) > 3:
            return "{}({}xDatasets)".format(type(self).__name__, len(self.datasets))
        head = ", ".join([dataset.__repr__() for dataset in self.datasets[:-1]])
        return "{}(".format(type(self).__name__) + head + ", " + \
            self.datasets[-1].__repr__() + ')'
class ZipReject(Zip):
    """
    Extends `Zip` by the functionality of rejecting samples that don't fulfill
    a specified rejection criterion.
    """
    def __init__(self, *datasets, sync=False, transforms=None,
                 rejection_dataset_indices, rejection_criterion,
                 random_jump_after_reject=True):
        """
        Parameters
        ----------
        datasets : list or tuple
            Datasets to zip.
        sync : bool
            Whether to synchronize zipped datasets if a synchronization primitive is
            available.
        transforms : callable
            Transforms to apply on the fetched batch.
        rejection_dataset_indices : int or list or tuple
            Indices (or index) corresponding to the datasets which are used to determine
            whether a batch should be rejected.
        rejection_criterion : callable
            Criterion for rejection of batch. Must be a callable that accepts one or more
            arrays / tensors and returns True if the corresponding batch should be rejected,
            False otherwise. Should accept as many inputs as the number of elements in
            `rejection_dataset_indices` if the latter is a list, and 1 otherwise.
            Note that the order of the inputs to the `rejection_criterion` is the same
            as the order of the indices in `rejection_dataset_indices`.
        random_jump_after_reject: bool
            Whether to try a random index or the rejected index incremented by one after
            rejection.
        """
        super(ZipReject, self).__init__(*datasets, sync=sync, transforms=transforms)
        # Validate every rejection index against the number of zipped datasets
        for rejection_dataset_index in pyu.to_iterable(rejection_dataset_indices):
            assert_(rejection_dataset_index < len(datasets),
                    "Index of the dataset to be used for rejection (= {}) is larger "
                    "than the number of datasets (= {}) minus one."
                    .format(rejection_dataset_index, len(datasets)),
                    IndexError)
        # Normalized to a list, even if a single index was given
        self.rejection_dataset_indices = pyu.to_iterable(rejection_dataset_indices)
        assert_(callable(rejection_criterion),
                "Rejection criterion is not callable as it should be.",
                TypeError)
        # return true if fetched should be rejected
        self.rejection_criterion = rejection_criterion
        # Array shared over processes to keep track of which indices have been rejected
        # ('b' = signed char, one flag per index; writable from DataLoader workers)
        self.rejected = mp.Array('b', len(self))
        # Per-process cache of not-yet-rejected indices; rebuilt lazily from
        # `self.rejected` in __getitem__ (None means "not built yet").
        self.available_indices = None
        # optional index mapping to exclude rejected indices, reducing dataset size (see remove_rejected())
        self.index_mapping = None
        self.random_jump_after_reject = random_jump_after_reject

    def remove_rejected(self):
        # remove the indices belonging to samples that were rejected from the dataset
        # this changes the length of the dataset
        rejected = np.array(self.rejected[:])
        # Keep only indices whose rejection flag is 0
        self.index_mapping = np.argwhere(1 - rejected)[:, 0]
        # Fresh (all-zero) shared flags for the now-smaller dataset
        self.rejected = mp.Array('b', len(self))
        # just in case of num_workers == 0
        self.available_indices = None

    def __len__(self):
        # If an index_mapping exists, it defines the (reduced) dataset size.
        if hasattr(self, 'index_mapping') and self.index_mapping is not None:
            return len(self.index_mapping)
        else:
            return super(ZipReject, self).__len__()

    def next_index_to_try(self, index):
        # Candidate index after a rejection: either a uniformly random jump or
        # the next index (with wrap-around), depending on configuration.
        if self.random_jump_after_reject:
            return np.random.randint(len(self))
        else:
            return (index + 1) % len(self)

    def fetch_from_rejection_datasets(self, index):
        # Fetch only from the datasets used for the rejection decision, in the
        # order given by `rejection_dataset_indices`.
        rejection_fetched = [self.datasets[rejection_dataset_index][index]
                             for rejection_dataset_index in self.rejection_dataset_indices]
        return rejection_fetched

    def __getitem__(self, index):
        # we increase the index until a valid batch of 'rejection_dataset' is found
        assert_(index < len(self), exception_type=IndexError)
        index_ = index
        # if we have a rejection dataset, check if the rejection criterion is fulfilled
        # and update the index
        if self.rejection_dataset_indices is not None:
            # at the start of each epoch, compute the available indices from the shared variable
            if self.available_indices is None:
                self.available_indices = set(np.argwhere(1 - np.array(self.rejected[:]))[:, 0])
            reject = True
            while reject:
                # check if there are no potentially valid indices left
                if not self.available_indices:
                    raise RuntimeError("ZipReject: No valid batch was found!")
                # check if this index was marked as rejected before
                if index_ not in self.available_indices:
                    index_ = self.next_index_to_try(index_)
                    continue
                # check if this index was marked as rejected in any process
                if self.rejected[index_]:
                    self.available_indices.remove(index_)
                    continue
                # map the index, if an index_mapping has been defined (see remove_rejected())
                mapped_index_ = index_ if self.index_mapping is None else self.index_mapping[index_]
                # we only fetch the dataset which has the rejection criterion
                # and only fetch all datasets when a valid index is found
                rejection_fetched = self.fetch_from_rejection_datasets(mapped_index_)
                # check if this batch is to be rejected
                reject = self.rejection_criterion(*rejection_fetched)
                # if so, increase the index and add it
                if reject:
                    self.rejected[index_] = True
                    self.available_indices.remove(index_)
            # fetch all other datasets and concatenate them with the valid rejection_fetch
            fetched = []
            for dataset_index, dataset in enumerate(self.datasets):
                if dataset_index in self.rejection_dataset_indices:
                    # Find the index in `rejection_fetched` corresponding to this dataset_index
                    index_in_rejection_fetched = self.rejection_dataset_indices.index(dataset_index)
                    # ... and append to fetched
                    fetched.append(rejection_fetched[index_in_rejection_fetched])
                else:
                    # Fetch and append to fetched
                    fetched.append(dataset[mapped_index_])
        else:
            # map the index, if an index_mapping has been defined (see remove_rejected())
            mapped_index_ = index_ if self.index_mapping is None else self.index_mapping[index_]
            fetched = [dataset[mapped_index_] for dataset in self.datasets]
        # apply transforms if present
        if self.transforms is not None:
            assert_(callable(self.transforms), "`self.transforms` is not callable.", TypeError)
            fetched = self.transforms(*fetched)
        return fetched
and append to fetched fetched.append(rejection_fetched[index_in_rejection_fetched]) else: # Fetch and append to fetched fetched.append(dataset[mapped_index_]) else: # map the index, if an index_mapping has been defined (see remove_rejected()) mapped_index_ = index_ if self.index_mapping is None else self.index_mapping[index_] fetched = [dataset[mapped_index_] for dataset in self.datasets] # apply transforms if present if self.transforms is not None: assert_(callable(self.transforms), "`self.transforms` is not callable.", TypeError) fetched = self.transforms(*fetched) return fetched ================================================ FILE: inferno/io/transform/__init__.py ================================================ from .base import Transform, Compose from . import generic from . import image from . import volume ================================================ FILE: inferno/io/transform/base.py ================================================ from ...utils import python_utils as pyu import numpy as np class Transform(object): """ Base class for a Transform. The argument `apply_to` (list) specifies the indices of the tensors this transform will be applied to. The following methods are recognized (in order of descending priority): - `batch_function`: Applies to all tensors in a batch simultaneously - `tensor_function`: Applies to just __one__ tensor at a time. - `volume_function`: For 3D volumes, applies to just __one__ volume at a time. - `image_function`: For 2D or 3D volumes, applies to just __one__ image at a time. For example, if both `volume_function` and `image_function` are defined, this means that only the former will be called. If the inputs are therefore not 5D batch-tensors of 3D volumes, a `NotImplementedError` is raised. """ def __init__(self, apply_to=None): """ Parameters ---------- apply_to : list or tuple Indices of tensors to apply this transform to. The indices are with respect to the list of arguments this object is called with. 
""" self._random_variables = {} self._apply_to = list(apply_to) if apply_to is not None else None def build_random_variables(self, **kwargs): pass def clear_random_variables(self): self._random_variables = {} def get_random_variable(self, key, default=None, build=True, **random_variable_building_kwargs): if key in self._random_variables: return self._random_variables.get(key, default) else: if not build: return default else: self.build_random_variables(**random_variable_building_kwargs) return self.get_random_variable(key, default, build=False) def set_random_variable(self, key, value): self._random_variables.update({key: value}) def __call__(self, *tensors, **transform_function_kwargs): tensors = pyu.to_iterable(tensors) # Get the list of the indices of the tensors to which we're going to apply the transform apply_to = list(range(len(tensors))) if self._apply_to is None else self._apply_to # Flush random variables and assume they're built by image_function self.clear_random_variables() if hasattr(self, 'batch_function'): transformed = self.batch_function(tensors, **transform_function_kwargs) return pyu.from_iterable(transformed) elif hasattr(self, 'tensor_function'): transformed = [self._apply_tensor_function(tensor, **transform_function_kwargs) if tensor_index in apply_to else tensor for tensor_index, tensor in enumerate(tensors)] return pyu.from_iterable(transformed) elif hasattr(self, 'volume_function'): # Loop over all tensors transformed = [self._apply_volume_function(tensor, **transform_function_kwargs) if tensor_index in apply_to else tensor for tensor_index, tensor in enumerate(tensors)] return pyu.from_iterable(transformed) elif hasattr(self, 'image_function'): # Loop over all tensors transformed = [self._apply_image_function(tensor, **transform_function_kwargs) if tensor_index in apply_to else tensor for tensor_index, tensor in enumerate(tensors)] return pyu.from_iterable(transformed) else: raise NotImplementedError # noinspection PyUnresolvedReferences 
def _apply_tensor_function(self, tensor, **transform_function_kwargs): if isinstance(tensor, list): return [self._apply_tensor_function(tens) for tens in tensor] return self.tensor_function(tensor) # noinspection PyUnresolvedReferences def _apply_image_function(self, tensor, **transform_function_kwargs): assert pyu.has_callable_attr(self, 'image_function') if isinstance(tensor, list): return [self._apply_image_function(tens) for tens in tensor] # 2D case if tensor.ndim == 4: return np.array([np.array([self.image_function(image, **transform_function_kwargs) for image in channel_image]) for channel_image in tensor]) # 3D case elif tensor.ndim == 5: return np.array([np.array([np.array([self.image_function(image, **transform_function_kwargs) for image in volume]) for volume in channel_volume]) for channel_volume in tensor]) elif tensor.ndim == 3: # Assume we have a 3D volume (signature zyx) and apply the image function # on all yx slices. return np.array([self.image_function(image, **transform_function_kwargs) for image in tensor]) elif tensor.ndim == 2: # Assume we really do have an image. return self.image_function(tensor, **transform_function_kwargs) else: raise NotImplementedError # noinspection PyUnresolvedReferences def _apply_volume_function(self, tensor, **transform_function_kwargs): assert pyu.has_callable_attr(self, 'volume_function') if isinstance(tensor, list): return [self._apply_volume_function(tens) for tens in tensor] # 3D case if tensor.ndim == 5: # tensor is bczyx # volume function is applied to zyx, i.e. loop over b and c # FIXME This loops one time too many return np.array([np.array([np.array([self.volume_function(volume, **transform_function_kwargs) for volume in channel_volume]) for channel_volume in batch]) for batch in tensor]) elif tensor.ndim == 4: # We're applying the volume function on a czyx tensor, i.e. 
# we loop over c and apply
            # volume function to (zyx)
            return np.array([self.volume_function(volume, **transform_function_kwargs)
                             for volume in tensor])
        elif tensor.ndim == 3:
            # We're applying the volume function on the volume itself
            return self.volume_function(tensor, **transform_function_kwargs)
        else:
            cname = self.__class__.__name__
            raise NotImplementedError("Volume function not implemented for ndim %i called in %s"
                                      % (tensor.ndim, cname))


class Compose(object):
    """Composes multiple callables (including but not limited to `Transform` objects)."""
    def __init__(self, *transforms):
        """
        Parameters
        ----------
        transforms : list of callable or tuple of callable
            Transforms to compose.
        """
        assert all([callable(transform) for transform in transforms])
        self.transforms = list(transforms)

    def add(self, transform):
        # Append another callable to the pipeline; returns self for chaining.
        assert callable(transform)
        self.transforms.append(transform)
        return self

    def remove(self, name):
        # Remove the first composed transform whose class name equals `name`
        # (no-op if absent); returns self for chaining.
        transform_idx = None
        for idx, transform in enumerate(self.transforms):
            if type(transform).__name__ == name:
                transform_idx = idx
                break
        if transform_idx is not None:
            self.transforms.pop(transform_idx)
        return self

    def __call__(self, *tensors):
        # Thread the tensors through every transform in order; each output is
        # re-iterablized so the next transform can unpack it with *args.
        intermediate = tensors
        for transform in self.transforms:
            intermediate = pyu.to_iterable(transform(*intermediate))
        return pyu.from_iterable(intermediate)


class DTypeMapping(object):
    # Maps dtype aliases (e.g. 'float', 'half', 'long', 'byte') to canonical
    # numpy dtype names. Mixed into transforms that accept a `dtype` argument.
    DTYPE_MAPPING = {'float32': 'float32',
                     'float': 'float32',
                     'double': 'float64',
                     'float64': 'float64',
                     'half': 'float16',
                     'float16': 'float16',
                     'long': 'int64',
                     'int64': 'int64',
                     'byte': 'uint8',
                     'uint8': 'uint8',
                     'int': 'int32',
                     'int32': 'int32'}


================================================ FILE: inferno/io/transform/generic.py ================================================
import numpy as np
import torch
from .base import Transform, DTypeMapping
from ...utils.exceptions import assert_, DTypeError


class Normalize(Transform):
    """Normalizes input to zero mean unit variance."""
    def __init__(self, eps=1e-4, mean=None, std=None, ignore_value=None,
**super_kwargs):
        """
        Parameters
        ----------
        eps : float
            A small epsilon for numerical stability.
        mean : list or float or numpy.ndarray
            Global dataset mean for all channels.
        std : list or float or numpy.ndarray
            Global dataset std for all channels.
        ignore_value : int or float, optional
            Background value excluded from the statistics and left unnormalized.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(Normalize, self).__init__(**super_kwargs)
        self.eps = eps
        self.mean = np.asarray(mean) if mean is not None else None
        self.std = np.asarray(std) if std is not None else None
        self.ignore_value = ignore_value

    def tensor_function(self, tensor):
        # if we have a background value that we don't want to normalize
        mask = None if self.ignore_value is None else (tensor != self.ignore_value)
        if mask is None:
            mean = np.asarray(tensor.mean()) if self.mean is None else self.mean
            std = np.asarray(tensor.std()) if self.std is None else self.std
        else:
            # Statistics are computed over the non-background entries only.
            mean = np.asarray(tensor[mask].mean()) if self.mean is None else self.mean
            std = np.asarray(tensor[mask].std()) if self.std is None else self.std
        # Figure out how to reshape mean and std (broadcast along trailing axes,
        # channel-first layout assumed).
        reshape_as = [-1] + [1] * (tensor.ndim - 1)
        # Normalize
        if mask is None:
            tensor = (tensor - mean.reshape(*reshape_as)) / (std.reshape(*reshape_as) + self.eps)
        else:
            # if tensor is int, the normalized tensor will be in int as well
            tensor = tensor.astype('float64')
            tensor[mask] = ((tensor - mean.reshape(*reshape_as)) \
                            / (std.reshape(*reshape_as) + self.eps))[mask]
        return tensor


class NormalizeRange(Transform):
    """Normalizes input by a constant."""
    def __init__(self, normalize_by=255., **super_kwargs):
        """
        Parameters
        ----------
        normalize_by : float or int
            Scalar to normalize by.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(NormalizeRange, self).__init__(**super_kwargs)
        self.normalize_by = float(normalize_by)

    def tensor_function(self, tensor):
        return tensor / self.normalize_by


class Project(Transform):
    """
    Given a projection mapping (i.e. a dict) and an input tensor, this transform
    replaces all values in the tensor that equal a key in the mapping with the
    value corresponding to the key.
    """
    def __init__(self, projection, **super_kwargs):
        """
        Parameters
        ----------
        projection : dict
            The projection mapping.
        super_kwargs : dict
            Keywords to the super class.
        """
        super(Project, self).__init__(**super_kwargs)
        self.projection = dict(projection)

    def tensor_function(self, tensor):
        # NOTE: values that do not occur as keys in the mapping end up as 0 in
        # the output, since it starts out as zeros.
        output = np.zeros_like(tensor)
        for source, target in self.projection.items():
            output[tensor == source] = target
        return output


class Label2OneHot(Transform, DTypeMapping):
    """Convert integer labels to one-hot vectors for arbitrary dimensional data."""
    def __init__(self, num_classes, dtype='float', **super_kwargs):
        """
        Parameters
        ----------
        num_classes : int
            Number of classes.
        dtype : str
            Datatype of the output.
        super_kwargs : dict
            Keyword arguments to the superclass.
        """
        super(Label2OneHot, self).__init__(**super_kwargs)
        self.num_classes = num_classes
        self.dtype = self.DTYPE_MAPPING.get(dtype)

    def tensor_function(self, tensor):
        # Compare a (num_classes, 1, ..., 1) arange against the labels; the
        # broadcasted equality yields the one-hot encoding with a leading class axis.
        reshaped_arange = np.arange(self.num_classes).reshape(-1, *(1,)*tensor.ndim)
        output = np.equal(reshaped_arange, tensor).astype(self.dtype)
        # output = np.zeros(shape=(self.num_classes,) + tensor.shape, dtype=self.dtype)
        # # Optimizing for simplicity and memory efficiency, because one would usually
        # # spawn multiple workers
        # for class_num in range(self.num_classes):
        #     output[class_num] = tensor == class_num
        return output


class Cast(Transform, DTypeMapping):
    """Casts inputs to a specified datatype."""
    def __init__(self, dtype='float', **super_kwargs):
        """
        Parameters
        ----------
        dtype : {'float16', 'float32', 'float64', 'half', 'float', 'double'}
            Datatype to cast to.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(Cast, self).__init__(**super_kwargs)
        assert dtype in self.DTYPE_MAPPING.keys()
        self.dtype = self.DTYPE_MAPPING.get(dtype)

    def tensor_function(self, tensor):
        # e.g. np.float32(tensor): cast via the numpy scalar-type constructor.
        return getattr(np, self.dtype)(tensor)


class AsTorchBatch(Transform):
    """Converts a given numpy array to a torch batch tensor.

    The result is a torch tensor __without__ the leading batch axis. For example,
    if the input is an image of shape `(100, 100)`, the output is a batch of shape
    `(1, 100, 100)`. The collate function will add the leading batch axis to obtain
    a tensor of shape `(N, 1, 100, 100)`, where `N` is the batch-size.
    """
    def __init__(self, dimensionality, add_channel_axis_if_necessary=True,
                 **super_kwargs):
        """
        Parameters
        ----------
        dimensionality : {1, 2, 3}
            Dimensionality of the data: 1 if vector, 2 if image, 3 if volume.
        add_channel_axis_if_necessary : bool
            Whether to add a channel axis where necessary. For example, if
            `dimensionality = 2` and the input tensor has 2 dimensions (i.e. an
            image), setting `add_channel_axis_if_necessary` to True results in the
            output being a 3 dimensional tensor, where the leading dimension is a
            singleton and corresponds to `channel`.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(AsTorchBatch, self).__init__(**super_kwargs)
        assert dimensionality in [1, 2, 3]
        self.dimensionality = dimensionality
        self.add_channel_axis_if_necessary = bool(add_channel_axis_if_necessary)

    def _to_batch(self, tensor):
        assert_(isinstance(tensor, np.ndarray),
                "Expected numpy array, got %s" % type(tensor),
                DTypeError)
        if self.dimensionality == 3:
            # We're dealing with a volume. tensor can either be 3D or 4D
            assert tensor.ndim in [3, 4]
            if tensor.ndim == 3 and self.add_channel_axis_if_necessary:
                # Add channel axis
                return torch.from_numpy(tensor[None, ...])
            else:
                # Channel axis is in already
                return torch.from_numpy(tensor)
        elif self.dimensionality == 2:
            # We're dealing with an image. tensor can either be 2D or 3D
            assert tensor.ndim in [2, 3]
            if tensor.ndim == 2 and self.add_channel_axis_if_necessary:
                # Add channel axis
                return torch.from_numpy(tensor[None, ...])
            else:
                # Channel axis is in already
                return torch.from_numpy(tensor)
        elif self.dimensionality == 1:
            # We're dealing with a vector - it has to be 1D
            assert tensor.ndim == 1
            return torch.from_numpy(tensor)
        else:
            raise NotImplementedError

    def tensor_function(self, tensor):
        assert_(isinstance(tensor, (list, np.ndarray)),
                "Expected numpy array or list, got %s" % type(tensor),
                DTypeError)
        if isinstance(tensor, np.ndarray):
            return self._to_batch(tensor)
        else:
            # Lists are converted element-wise.
            return [self._to_batch(elem) for elem in tensor]


================================================ FILE: inferno/io/transform/image.py ================================================
import numpy as np
from scipy.ndimage import zoom
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates, rotate
from scipy.ndimage.morphology import binary_dilation, binary_erosion
from skimage.exposure import adjust_gamma
from warnings import catch_warnings, simplefilter
from .base import Transform
from ...utils.exceptions import assert_, ShapeError


class PILImage2NumPyArray(Transform):
    """Convert a PIL Image object to a numpy array.

    For images with multiple channels (say RGB), the channel axis is moved to
    front. Therefore, a (100, 100, 3) RGB image becomes an array of shape
    (3, 100, 100).
    """
    def tensor_function(self, tensor):
        tensor = np.asarray(tensor)
        if tensor.ndim == 3:
            # There's a channel axis - we move it to front
            tensor = np.moveaxis(tensor, source=-1, destination=0)
        elif tensor.ndim == 2:
            pass
        else:
            raise NotImplementedError("Expected tensor to be a 2D or 3D "
                                      "numpy array, got a {}D array instead."
                                      .format(tensor.ndim))
        return tensor


class Scale(Transform):
    """Scales an image to a given size with spline interpolation of requested order.
Unlike torchvision.transforms.Scale, this does not depend on PIL and therefore works with numpy arrays. If you do have a PIL image and wish to use this transform, consider applying `PILImage2NumPyArray` first. Warnings -------- This transform uses `scipy.ndimage.zoom` and requires scipy >= 0.13.0 to work correctly. """ def __init__(self, output_image_shape, interpolation_order=3, zoom_kwargs=None, **super_kwargs): """ Parameters ---------- output_image_shape : list or tuple or int or None Target size of the output image. Aspect ratio may not be preserved. If output_image_shape is None, image input size will be preserved interpolation_order : int Interpolation order for the spline interpolation. zoom_kwargs : dict Keyword arguments for `scipy.ndimage.zoom`. super_kwargs : dict Keyword arguments for the superclass. """ super(Scale, self).__init__(**super_kwargs) if output_image_shape is not None: output_image_shape = (output_image_shape, output_image_shape) \ if isinstance(output_image_shape, int) else tuple(output_image_shape) assert_(len(output_image_shape) == 2, "`output_image_shape` must be an integer or a tuple of length 2.", ValueError) self.output_image_shape = output_image_shape self.interpolation_order = interpolation_order self.zoom_kwargs = {} if zoom_kwargs is None else dict(zoom_kwargs) def image_function(self, image): source_height, source_width = image.shape target_height, target_width = self.output_image_shape # We're on Python 3 - take a deep breath and relax. zoom_height, zoom_width = (target_height / source_height), (target_width / source_width) with catch_warnings(): # Ignore warning that scipy should be > 0.13 (it's 0.19 these days) simplefilter('ignore') rescaled_image = zoom(image, (zoom_height, zoom_width), order=self.interpolation_order, **self.zoom_kwargs) # This should never happen assert_(rescaled_image.shape == (target_height, target_width), "Shape mismatch that shouldn't have happened if you were on scipy > 0.13.0. 
" "Are you on scipy > 0.13.0?", ShapeError) return rescaled_image class RandomCrop(Transform): """Crop input to a given size. This is similar to torchvision.transforms.RandomCrop, except that it operates on numpy arrays instead of PIL images. If you do have a PIL image and wish to use this transform, consider applying `PILImage2NumPyArray` first. Warnings -------- If `output_image_shape` is larger than the image itself, the image is not cropped (along the relevant dimensions). """ def __init__(self, output_image_shape, **super_kwargs): """ Parameters ---------- output_image_shape : tuple or list or int Expected shape of the output image. Could be an integer, (say) 100, in which case it's interpreted as `(100, 100)`. Note that if the image shape along some (or all) dimension is smaller, say `(50, 200)`, the resulting output images will have the shape `(50, 100)`. super_kwargs : dict Keywords to the super class. """ super(RandomCrop, self).__init__(**super_kwargs) # Privates self._image_shape_cache = None # Publics output_image_shape = (output_image_shape, output_image_shape) \ if isinstance(output_image_shape, int) else tuple(output_image_shape) assert_(len(output_image_shape) == 2, "`output_image_shape` must be an integer or a tuple of length 2.", ValueError) self.output_image_shape = output_image_shape def clear_random_variables(self): self._image_shape_cache = None super(RandomCrop, self).clear_random_variables() def build_random_variables(self, height_leeway, width_leeway): if height_leeway > 0: self.set_random_variable('height_location', np.random.randint(low=0, high=height_leeway + 1)) if width_leeway > 0: self.set_random_variable('width_location', np.random.randint(low=0, high=width_leeway + 1)) def image_function(self, image): # Validate image shape if self._image_shape_cache is not None: assert_(self._image_shape_cache == image.shape, "RandomCrop works on multiple images simultaneously only " "if they have the same shape. 
Was expecting an image of " "shape {}, got one of shape {} instead." .format(self._image_shape_cache, image.shape), ShapeError) else: self._image_shape_cache = image.shape source_height, source_width = image.shape crop_height, crop_width = self.output_image_shape height_leeway = source_height - crop_height width_leeway = source_width - crop_width if height_leeway > 0: # Crop height height_location = self.get_random_variable('height_location', height_leeway=height_leeway, width_leeway=width_leeway) cropped = image[height_location:(height_location + crop_height), :] assert cropped.shape[0] == self.output_image_shape[0], "Well, shit." else: cropped = image if width_leeway > 0: # Crop width width_location = self.get_random_variable('width_location', height_leeway=height_leeway, width_leeway=width_leeway) cropped = cropped[:, width_location:(width_location + crop_width)] assert cropped.shape[1] == self.output_image_shape[1], "Well, shit." return cropped class RandomSizedCrop(Transform): """Extract a randomly sized crop from the image. The ratio of the sizes of the cropped and the original image can be limited within specified bounds along both axes. To resize back to a constant sized image, compose with `Scale`. """ def __init__(self, ratio_between=None, height_ratio_between=None, width_ratio_between=None, preserve_aspect_ratio=False, relative_target_aspect_ratio=None, **super_kwargs): """ Parameters ---------- ratio_between : tuple Specify the bounds between which to sample the crop ratio. This applies to both height and width if not overriden. Can be None if both height and width ratios are specified individually. height_ratio_between : tuple Specify the bounds between which to sample the vertical crop ratio. Can be None if `ratio_between` is not None. width_ratio_between : tuple Specify the bounds between which to sample the horizontal crop ratio. Can be None if `ratio_between` is not None. preserve_aspect_ratio : bool Whether to preserve aspect ratio. 
If both `height_ratio_between` and `width_ratio_between` are specified, the former is used if this is set to True. relative_target_aspect_ratio : float Specify the target aspect ratio (W x H) relative to the input image (i.e. by mapping the input image ratio to 1:1). For instance, if an image has the size 1024 (H) x 2048 (W), a relative target aspect ratio of 0.5 might yield images of size 1024 x 1024. Note that this only applies if `preserve_aspect_ratio` is set to False. super_kwargs : dict Keyword arguments for the super class. """ super(RandomSizedCrop, self).__init__(**super_kwargs) # Privates self._image_shape_cache = None # Publics height_ratio_between = tuple(height_ratio_between) \ if height_ratio_between is not None else tuple(ratio_between) width_ratio_between = tuple(width_ratio_between) \ if width_ratio_between is not None else tuple(ratio_between) assert_(height_ratio_between is not None, "`height_ratio_between` is not specified.", ValueError) assert_(width_ratio_between is not None, "`width_ratio_between` is not specified.", ValueError) self.height_ratio_between = height_ratio_between self.width_ratio_between = width_ratio_between self.preserve_aspect_ratio = preserve_aspect_ratio self.relative_target_aspect_ratio = relative_target_aspect_ratio def build_random_variables(self, image_shape): # Seed RNG np.random.seed() # Compute random variables source_height, source_width = image_shape height_ratio = np.random.uniform(low=self.height_ratio_between[0], high=self.height_ratio_between[1]) if self.preserve_aspect_ratio: width_ratio = height_ratio elif self.relative_target_aspect_ratio is not None: width_ratio = height_ratio * self.relative_target_aspect_ratio else: width_ratio = np.random.uniform(low=self.width_ratio_between[0], high=self.width_ratio_between[1]) crop_height = int(np.round(height_ratio * source_height)) crop_width = int(np.round(width_ratio * source_width)) height_leeway = source_height - crop_height width_leeway = source_width - 
crop_width
        # Set random variables
        if height_leeway > 0:
            self.set_random_variable('height_location',
                                     np.random.randint(low=0, high=height_leeway + 1))
        if width_leeway > 0:
            self.set_random_variable('width_location',
                                     np.random.randint(low=0, high=width_leeway + 1))
        self.set_random_variable('crop_height', crop_height)
        self.set_random_variable('crop_width', crop_width)
        self.set_random_variable('height_leeway', height_leeway)
        self.set_random_variable('width_leeway', width_leeway)

    def image_function(self, image):
        # Validate image shape
        if self._image_shape_cache is not None:
            assert_(self._image_shape_cache == image.shape,
                    "RandomCrop works on multiple images simultaneously only "
                    "if they have the same shape. Was expecting an image of "
                    "shape {}, got one of shape {} instead."
                    .format(self._image_shape_cache, image.shape),
                    ShapeError)
        else:
            self._image_shape_cache = image.shape
        height_leeway = self.get_random_variable('height_leeway',
                                                 image_shape=image.shape)
        width_leeway = self.get_random_variable('width_leeway',
                                                image_shape=image.shape)
        if height_leeway > 0:
            height_location = self.get_random_variable('height_location',
                                                       image_shape=image.shape)
            crop_height = self.get_random_variable('crop_height',
                                                   image_shape=image.shape)
            cropped = image[height_location:(height_location + crop_height), :]
        else:
            cropped = image
        if width_leeway > 0:
            width_location = self.get_random_variable('width_location',
                                                      image_shape=image.shape)
            crop_width = self.get_random_variable('crop_width',
                                                  image_shape=image.shape)
            cropped = cropped[:, width_location:(width_location + crop_width)]
        return cropped


class RandomGammaCorrection(Transform):
    """Applies gamma correction [1] with a random gamma.

    This transform uses `skimage.exposure.adjust_gamma`, which requires the input
    be positive.

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Gamma_correction
    """
    def __init__(self, gamma_between=(0.5, 2.), gain=1, **super_kwargs):
        """
        Parameters
        ----------
        gamma_between : tuple or list
            Specifies the range within which to sample gamma (uniformly).
        gain : int or float
            The resulting gamma corrected image is multiplied by this `gain`.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        super(RandomGammaCorrection, self).__init__(**super_kwargs)
        self.gamma_between = list(gamma_between)
        self.gain = gain

    def build_random_variables(self):
        np.random.seed()
        self.set_random_variable('gamma',
                                 np.random.uniform(low=self.gamma_between[0],
                                                   high=self.gamma_between[1]))

    def image_function(self, image):
        gamma_adjusted = adjust_gamma(image,
                                      gamma=self.get_random_variable('gamma'),
                                      gain=self.gain)
        return gamma_adjusted


class ElasticTransform(Transform):
    """Random Elastic Transformation."""
    # Dtypes scipy's map_coordinates handles natively; everything else is cast.
    NATIVE_DTYPES = {'float32', 'float64'}
    PREFERRED_DTYPE = 'float32'

    def __init__(self, alpha, sigma, order=1, invert=False, **super_kwargs):
        # alpha scales the displacement field, sigma smooths it; order is the
        # spline interpolation order; invert flips the displacement direction.
        self._initial_dtype = None
        super(ElasticTransform, self).__init__(**super_kwargs)
        self.alpha = alpha
        self.sigma = sigma
        self.order = order
        self.invert = invert

    def build_random_variables(self, **kwargs):
        # All this is done just once per batch (i.e. until `clear_random_variables` is called)
        np.random.seed()
        imshape = kwargs.get('imshape')
        # Build and scale random fields
        random_field_x = np.random.uniform(-1, 1, imshape) * self.alpha
        random_field_y = np.random.uniform(-1, 1, imshape) * self.alpha
        # Smooth random field (this has to be done just once per reset)
        sdx = gaussian_filter(random_field_x, self.sigma, mode='reflect')
        sdy = gaussian_filter(random_field_y, self.sigma, mode='reflect')
        # Make meshgrid
        x, y = np.meshgrid(np.arange(imshape[1]), np.arange(imshape[0]))
        # Make inversion coefficient
        _inverter = 1. if not self.invert else -1.
        # Distort meshgrid indices (invert if required)
        flow_y, flow_x = (y + _inverter * sdy).reshape(-1, 1), \
                         (x + _inverter * sdx).reshape(-1, 1)
        # Set random states
        self.set_random_variable('flow_x', flow_x)
        self.set_random_variable('flow_y', flow_y)

    def cast(self, image):
        # Remember the incoming dtype so `uncast` can restore it later.
        if image.dtype not in self.NATIVE_DTYPES:
            self._initial_dtype = image.dtype
            image = image.astype(self.PREFERRED_DTYPE)
        return image

    def uncast(self, image):
        # Restore the dtype remembered by `cast` (if any) and forget it.
        if self._initial_dtype is not None:
            image = image.astype(self._initial_dtype)
        self._initial_dtype = None
        return image

    def image_function(self, image):
        # Cast image to one of the native dtypes (one which that is supported by scipy)
        image = self.cast(image)
        # Take measurements
        imshape = image.shape
        # Obtain flows
        flows = self.get_random_variable('flow_y', imshape=imshape), \
                self.get_random_variable('flow_x', imshape=imshape)
        # Map cooordinates from image to distorted index set
        transformed_image = map_coordinates(image, flows,
                                            mode='reflect',
                                            order=self.order).reshape(imshape)
        # Uncast image to the original dtype
        transformed_image = self.uncast(transformed_image)
        return transformed_image


class AdditiveGaussianNoise(Transform):
    """Add gaussian noise to the input."""
    def __init__(self, sigma, **super_kwargs):
        super(AdditiveGaussianNoise, self).__init__(**super_kwargs)
        self.sigma = sigma

    def build_random_variables(self, **kwargs):
        np.random.seed()
        self.set_random_variable('noise',
                                 np.random.normal(loc=0, scale=self.sigma,
                                                  size=kwargs.get('imshape')))

    def image_function(self, image):
        image = image + self.get_random_variable('noise', imshape=image.shape)
        return image


class RandomRotate(Transform):
    """Random 90-degree rotations."""
    def __init__(self, **super_kwargs):
        super(RandomRotate, self).__init__(**super_kwargs)

    def build_random_variables(self, **kwargs):
        np.random.seed()
        # Number of quarter-turns, 0 to 3.
        self.set_random_variable('k', np.random.randint(0, 4))

    def image_function(self, image):
        return np.rot90(image, k=self.get_random_variable('k'))


class RandomTranspose(Transform):
    """Random 2d transpose."""
    def __init__(self, **super_kwargs):
        super(RandomTranspose, self).__init__(**super_kwargs)

    def build_random_variables(self, **kwargs):
        np.random.seed()
        self.set_random_variable('do_transpose', np.random.uniform() > 0.5)

    def image_function(self, image):
        if self.get_random_variable('do_transpose'):
            image = np.transpose(image)
        return image


class RandomFlip(Transform):
    """Random left-right or up-down flips."""
    def __init__(self, allow_lr_flips=True, allow_ud_flips=True, **super_kwargs):
        super(RandomFlip, self).__init__(**super_kwargs)
        self.allow_lr_flips = allow_lr_flips
        self.allow_ud_flips = allow_ud_flips

    def build_random_variables(self, **kwargs):
        np.random.seed()
        self.set_random_variable('flip_lr', np.random.uniform() > 0.5)
        self.set_random_variable('flip_ud', np.random.uniform() > 0.5)

    def image_function(self, image):
        if self.allow_lr_flips and self.get_random_variable('flip_lr'):
            image = np.fliplr(image)
        if self.allow_ud_flips and self.get_random_variable('flip_ud'):
            image = np.flipud(image)
        return image


class CenterCrop(Transform):
    """ Crop patch of size `size` from the center of the image """
    def __init__(self, size, **super_kwargs):
        super(CenterCrop, self).__init__(**super_kwargs)
        assert isinstance(size, (int, tuple))
        self.size = (size, size) if isinstance(size, int) else size

    def image_function(self, image):
        h, w = image.shape
        th, tw = self.size
        # Only crop along axes where the image is larger than the target.
        if h > th:
            y1 = int(round((h - th) / 2.))
            image = image[y1:y1 + th, :]
        if w > tw:
            x1 = int(round((w - tw) / 2.))
            image = image[:, x1:x1 + tw]
        return image


class BinaryMorphology(Transform):
    """
    Apply a binary morphology operation on an image. Supported operations are
    dilation and erosion.
    """
    def __init__(self, mode, num_iterations=1, morphology_kwargs=None,
                 **super_kwargs):
        """
        Parameters
        ----------
        mode : {'dilate', 'erode'}
            Whether to dilate or erode.
        num_iterations : int
            Number of iterations to apply the operation for.
        morphology_kwargs: dict
            Keyword arguments to the morphology function
            (i.e. `scipy.ndimage.morphology.binary_dilation` or
            `scipy.ndimage.morphology.binary_erosion`)
        super_kwargs : dict
            Keyword arguments to the superclass.
        """
        super(BinaryMorphology, self).__init__(**super_kwargs)
        # Validate and assign mode
        assert_(mode in ['dilate', 'erode'],
                "Mode must be one of ['dilate', 'erode']. Got {} instead.".format(mode),
                ValueError)
        self.mode = mode
        self.num_iterations = num_iterations
        self.morphology_kwargs = {} if morphology_kwargs is None else dict(morphology_kwargs)

    def image_function(self, image):
        if self.mode == 'dilate':
            transformed_image = binary_dilation(image, iterations=self.num_iterations,
                                                **self.morphology_kwargs)
        elif self.mode == 'erode':
            transformed_image = binary_erosion(image, iterations=self.num_iterations,
                                               **self.morphology_kwargs)
        else:
            raise ValueError
        # Cast transformed image to the right dtype and return
        return transformed_image.astype(image.dtype)


class BinaryDilation(BinaryMorphology):
    """Apply a binary dilation operation on an image."""
    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):
        super(BinaryDilation, self).__init__(mode='dilate',
                                             num_iterations=num_iterations,
                                             morphology_kwargs=morphology_kwargs,
                                             **super_kwargs)


class BinaryErosion(BinaryMorphology):
    """Apply a binary erosion operation on an image."""
    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):
        super(BinaryErosion, self).__init__(mode='erode',
                                            num_iterations=num_iterations,
                                            morphology_kwargs=morphology_kwargs,
                                            **super_kwargs)


class FineRandomRotations(Transform):
    """
    Random Rotation with random uniform angle distribution
    batch_function applies to rotation of input and label image

    Parameters
    ----------
    angle_range : int
        maximum angle of rotation
    axes : tuple, default (1,2) assuming that channel axis is 0
        pair of axis that define the 2d-plane of rotation
    mask_label : constant value that is used to pad the label images
    """
    def
__init__(self, angle_range, axes=(1,2), mask_label=0, **super_kwargs):
        super(FineRandomRotations, self).__init__(**super_kwargs)
        self.angle_range = angle_range
        self.axes = axes
        self.ml = mask_label

    def build_random_variables(self):
        np.random.seed()
        self.set_random_variable('angle',
                                 np.random.uniform(low=-self.angle_range,
                                                   high=self.angle_range))

    def batch_function(self, image):
        # `image` is the (input, label) pair; the label is rotated with order 0
        # (nearest neighbour) and padded with the mask label.
        angle = self.get_random_variable('angle')
        return rotate(image[0], angle, axes=self.axes, reshape=False), \
               rotate(image[1], angle, axes=self.axes, order=0,
                      cval=self.ml, reshape=False)


class RandomScaleSegmentation(Transform):
    """
    Random Scale input and label image

    Parameters
    ----------
    scale_range : tuple of floats defining (min, max) scales
        maximum angle of rotation
    resize : if True, image is cropped or padded to the original size
    pad_const: value used for constant padding
    """
    def __init__(self, scale_range, resize=True, pad_const=0, **super_kwargs):
        super(RandomScaleSegmentation, self).__init__(**super_kwargs)
        self.scale_range = scale_range
        self.resize = resize
        self.pad_const = pad_const

    def build_random_variables(self):
        np.random.seed()
        self.set_random_variable('seg_scale',
                                 np.random.uniform(low=self.scale_range[0],
                                                   high=self.scale_range[1]))

    def batch_function(self, image):
        scale = self.get_random_variable('seg_scale')
        input_image, segmentation = image
        image_shape = np.array(input_image.shape[1:])
        # Add a channel axis to the segmentation if it lacks one.
        if input_image.ndim == segmentation.ndim + 1:
            segmentation = segmentation[None]
        with catch_warnings():
            simplefilter('ignore')
            # Image is rescaled with cubic, segmentation with nearest-neighbour
            # interpolation, channel by channel.
            img = np.stack([zoom(x, scale, order=3) for x in input_image])
            seg = np.stack([zoom(x, scale, order=0) for x in segmentation])
        new_shape = np.array(img.shape[1:])
        if self.resize:
            if scale > 1.:
                # NOTE(review): this branch *crops* the upscaled result back to the
                # original size — the previous comment said "pad", which was swapped
                # with the comment on the else-branch below.
                crop_l = (new_shape - image_shape) // 2
                crop_r = new_shape - image_shape - crop_l
                cropping = [slice(None)] + [slice(c[0] if c[0] > 0 else None,
                                                  -c[1] if c[1] > 0 else None)
                                            for c in zip(crop_l, crop_r)]
                img = img[cropping]
                seg = seg[cropping]
            else:
                #
# NOTE(review): this branch *pads* the downscaled result back to the
                # original size — the previous comment said "crop", which was swapped
                # with the comment on the scale > 1 branch above.
                pad_l = (image_shape - new_shape) // 2
                pad_r = image_shape - new_shape - pad_l
                padding = [(0,0)] + list(zip(pad_l, pad_r))
                img = np.pad(img, padding, 'constant', constant_values=self.pad_const)
                seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
        return img, seg


================================================ FILE: inferno/io/transform/volume.py ================================================
import numpy as np
import scipy
from scipy.ndimage import zoom
from scipy.ndimage.morphology import binary_dilation, binary_erosion
from .base import Transform
from ...utils.exceptions import assert_


class RandomFlip3D(Transform):
    """Random flips along each of the three volume axes, drawn independently."""
    def __init__(self, **super_kwargs):
        super(RandomFlip3D, self).__init__(**super_kwargs)

    def build_random_variables(self, **kwargs):
        np.random.seed()
        self.set_random_variable('flip_lr', np.random.uniform() > 0.5)
        self.set_random_variable('flip_ud', np.random.uniform() > 0.5)
        self.set_random_variable('flip_z', np.random.uniform() > 0.5)

    def volume_function(self, volume):
        # Copies avoid handing negatively-strided views downstream.
        if self.get_random_variable('flip_lr'):
            volume = volume[:, :, ::-1].copy()
        if self.get_random_variable('flip_ud'):
            volume = volume[:, ::-1, :].copy()
        if self.get_random_variable('flip_z'):
            volume = volume[::-1, :, :].copy()
        return volume


class RandomRot3D(Transform):
    """Random rotations about the z, y and x axes, each applied with probability `p`."""
    def __init__(self, rot_range, p=0.125, reshape=False, order=0,
                 mode='nearest', **super_kwargs):
        super(RandomRot3D, self).__init__(**super_kwargs)
        self.rot_range = rot_range
        self.p = p
        self.reshape = reshape
        self.order = order
        self.mode = mode

    def build_random_variables(self, **kwargs):
        np.random.seed()
        # One independent coin flip and one angle per axis.
        self.set_random_variable('do_z', np.random.uniform() < self.p)
        self.set_random_variable('do_y', np.random.uniform() < self.p)
        self.set_random_variable('do_x', np.random.uniform() < self.p)
        self.set_random_variable('angle_z',
                                 np.random.uniform(-self.rot_range, self.rot_range))
        self.set_random_variable('angle_y',
                                 np.random.uniform(-self.rot_range, self.rot_range))
def volume_function(self, volume):
    """Apply up to three independent random rotations to a 3d volume.

    Each axis rotation is applied with probability `self.p` (drawn in
    `build_random_variables`) with its own sampled angle.

    Parameters
    ----------
    volume : np.ndarray
        3d input volume.

    Returns
    -------
    np.ndarray
        The (possibly) rotated volume.
    """
    angle_z = self.get_random_variable('angle_z')
    angle_y = self.get_random_variable('angle_y')
    angle_x = self.get_random_variable('angle_x')
    # NOTE: `scipy.ndimage.interpolation.rotate` was replaced by
    # `scipy.ndimage.rotate`; the old namespace is deprecated/removed in
    # recent SciPy releases
    # rotate along z-axis
    if self.get_random_variable('do_z'):
        volume = scipy.ndimage.rotate(volume, angle_z, order=self.order,
                                      mode=self.mode, axes=(0, 1),
                                      reshape=self.reshape)
    # rotate along y-axis
    if self.get_random_variable('do_y'):
        volume = scipy.ndimage.rotate(volume, angle_y, order=self.order,
                                      mode=self.mode, axes=(0, 2),
                                      reshape=self.reshape)
    # rotate along x-axis
    # BUG FIX: this used to test 'do_y' again, so the x-rotation was coupled
    # to the y-rotation and 'do_x' was never consulted
    if self.get_random_variable('do_x'):
        volume = scipy.ndimage.rotate(volume, angle_x, order=self.order,
                                      mode=self.mode, axes=(1, 2),
                                      reshape=self.reshape)
    return volume
class VolumeCenterCrop(Transform):
    """ Crop a patch of size `size` from the center of the volume """
    def __init__(self, size, **super_kwargs):
        """
        Parameters
        ----------
        size : int or tuple
            Size of the center crop; an int is expanded to all three axes.
        super_kwargs : dict
            Keyword arguments for the `Transform` base class.
        """
        super().__init__(**super_kwargs)
        assert isinstance(size, (int, tuple))
        self.size = (size, size, size) if isinstance(size, int) else size
        # BUG FIX: this used to be `assert len(size) == 3`, which raises a
        # TypeError whenever `size` is an int (ints have no len); validate
        # the normalized tuple instead
        assert len(self.size) == 3

    def volume_function(self, volume):
        h, w, d = volume.shape
        th, tw, td = self.size
        # BUG FIX: the crop offsets/extents were mixed up between the first
        # two axes (axis 0 was cropped with the axis-1 extent and vice
        # versa), which produced wrong crops for non-cubic sizes
        y1 = int(round((h - th) / 2.))
        x1 = int(round((w - tw) / 2.))
        z1 = int(round((d - td) / 2.))
        return volume[y1:y1 + th, x1:x1 + tw, z1:z1 + td]
def batch_function(self, batch):
    """Fold one spatial axis of the input into channels.

    The axis whose extent equals `self.channels` is moved to the front of
    the input; every `self.downsampling`-th slice around the central one is
    kept as a channel. For the target, only the central slice is kept.

    Parameters
    ----------
    batch : tuple
        `(input, target)` numpy arrays, both with an axis of size
        `self.channels`.

    Returns
    -------
    tuple
        `(new_input, new_target)`.

    Raises
    ------
    ValueError
        If no axis of the input has the configured channel count.
    """
    try:
        # the axis to fold into channels is the one whose extent equals
        # the configured channel count
        axis = batch[0].shape.index(self.channels)
    except ValueError:
        # BUG FIX: this used to print a misleading message and then fall
        # through to an unbound `axis`, crashing with a NameError
        raise ValueError("No axis of the input shape %s matches the desired "
                         "number of channels (%i)"
                         % (str(batch[0].shape), self.channels))
    half = self.channels // 2
    new_input = np.moveaxis(batch[0], axis, 0)
    # take every n-th slice in both directions from the central slice
    indices = [i for i in range(self.channels)
               if i % self.downsampling == half % self.downsampling]
    new_input = new_input[indices]
    # for the target, keep only the central slice and discard the rest
    new_target = np.moveaxis(batch[1], axis, 0)
    new_target = new_target[half]
    return (new_input, new_target)
""" super(RandomScale3D, self).__init__(**super_kwargs) assert_(len(zoom_factor_range) == 2, "`zoom_factor_range` must be a list or a tuple of length 2.", ValueError) self.min = zoom_factor_range[0] self.max = zoom_factor_range[1] self.interpolation_order = interpolation_order self.p = p self.same_zoom = same_zoom self.zoom_kwargs = {} if zoom_kwargs is None else dict(zoom_kwargs) def build_random_variables(self, **kwargs): np.random.seed() self.set_random_variable('do_z', np.random.uniform() < self.p) self.set_random_variable('do_y', np.random.uniform() < self.p) self.set_random_variable('do_x', np.random.uniform() < self.p) self.set_random_variable('zoom_z', np.random.uniform(self.min, self.max)) self.set_random_variable('zoom_y', np.random.uniform(self.min, self.max)) self.set_random_variable('zoom_x', np.random.uniform(self.min, self.max)) def volume_function(self, volume): zoom_z = self.get_random_variable('zoom_z') \ if self.get_random_variable('do_z') else 1 zoom_y = self.get_random_variable('zoom_y') \ if self.get_random_variable('do_y') else 1 zoom_x = self.get_random_variable('zoom_x') \ if self.get_random_variable('do_x') else 1 if self.same_zoom: zoom_y, zoom_x = zoom_z, zoom_z zoomed_volume = zoom(volume, (zoom_z, zoom_y, zoom_x), order=self.interpolation_order, **self.zoom_kwargs) return zoomed_volume class RandomBinaryMorphology3D(Transform): """ Apply a random binary morphology operation (dilation or erosion). Allowed range of iteration number can be set. """ def __init__(self, p=0.5, num_iter_range=(1, 5), morphology_kwargs=None, **super_kwargs): """ Parameters ---------- p : float Probability that any operation is applied num_iter_range : list or tuple The allowed range of iteration number to apply the operation for. morphology_kwargs: dict Keyword arguments to the morphology function (i.e. `scipy.ndimage.morphology.binary_erosion` or `scipy.ndimage.morphology.binary_erosion`) super_kwargs : dict Keyword arguments to the superclass. 
""" super(RandomBinaryMorphology3D, self).__init__(**super_kwargs) assert_(len(num_iter_range) == 2, "`num_iter_range` must be a list or a tuple of length 2.", ValueError) self.p = p self.min_iter = num_iter_range[0] self.max_iter = num_iter_range[1] + 1 self.morphology_kwargs = {} if morphology_kwargs is None else dict(morphology_kwargs) def build_random_variables(self, **kwargs): np.random.seed() self.set_random_variable('do', np.random.uniform() < self.p) self.set_random_variable('erode', np.random.uniform() < 0.5) self.set_random_variable('iter_num', np.random.randint(self.min_iter, self.max_iter)) def volume_function(self, volume): do = self.get_random_variable('do') erode_mode = self.get_random_variable('erode') iter_num = self.get_random_variable('iter_num') if do: if erode_mode: transformed_volume = binary_erosion(volume, iterations=iter_num, **self.morphology_kwargs) else: transformed_volume = binary_dilation(volume, iterations=iter_num, **self.morphology_kwargs) volume = transformed_volume.astype(volume.dtype) return volume class CropPad2Divisible(Transform): """ Given the number, symmetrically crops/pads the volume for all dimensions to be divisible by this number. Used e.g. to feed input with any shape to models with pooling layers. The threshold of cropping vs padding can be specified. """ def __init__(self, divisor=16, crop_pad_threshold=0.2, mode='constant', padding_kwargs=None, **super_kwargs): """ Parameters ---------- divisor : int A number that all dimensions should be divisible by crop_pad_threshold : float When "division remainder to divisor" ratio is lower then this number, input volume will be cropped, otherwise - padded. Set to 0 to only pad and 1 to only crop. mode: ‘constant’, ‘edge’, ‘symmetric’, etc See all the possible modes in numpy.pad doc padding_kwargs: dict Keyword arguments to numpy.pad super_kwargs : dict Keyword arguments to the superclass. 
""" super(CropPad2Divisible, self).__init__(**super_kwargs) assert_(0 <= crop_pad_threshold <= 1, "threshold must be between 0 and 1 inclusive", ValueError) assert_(divisor % 2 == 0, "divisor must be an even number", ValueError) self.divisor = divisor self.crop_pad_threshold = crop_pad_threshold self.mode = mode self.padding_kwargs = {} if padding_kwargs is None else dict(padding_kwargs) def volume_function(self, volume): half_div = int(self.divisor/2) remainders = [axis % self.divisor for axis in volume.shape] to_pad = [remainder/self.divisor >= self.crop_pad_threshold for remainder in remainders] diffs = [(int(np.floor(remainder/2)), int(np.ceil(remainder/2))) for remainder in remainders] padding = [(half_div - diff[0], half_div - diff[1]) if pad else (0, 0) for diff, pad in zip(diffs, to_pad)] cropping = [slice(diff[0], -diff[1]) if not (pad or diff[1] == 0) else slice(None, None) for diff, pad in zip(diffs, to_pad)] volume = np.pad(volume, pad_width=padding, mode=self.mode, **self.padding_kwargs) volume = volume[cropping] return volume class CropPad2Size(Transform): """ Adjust the input volume to the given size: Symmetrically crops if input > size, symmetrically pads if input < size. """ def __init__(self, output_size, mode='constant', padding_kwargs=None, **super_kwargs): """ Parameters ---------- output_size : int, tuple or list The output size. If int, the same value is used for all axes mode: `constant`, `edge`, `symmetric`, etc See all the possible modes in numpy.pad doc padding_kwargs: dict Keyword arguments to numpy.pad super_kwargs : dict Keyword arguments to the superclass. 
""" super(CropPad2Size, self).__init__(**super_kwargs) self.output_size = output_size if isinstance(output_size, (list, tuple)) \ else (output_size, ) * 3 assert len(self.output_size) == 3, 'The size should be given for all the dimensions' self.mode = mode self.padding_kwargs = {} if padding_kwargs is None else dict(padding_kwargs) def volume_function(self, volume): difference = [inp - outp for inp, outp in zip(volume.shape, self.output_size)] to_pad = [diff < 0 for diff in difference] to_crop = [diff > 0 for diff in difference] diffs = [(int(np.floor(diff/2)), int(np.ceil(diff/2))) for diff in np.abs(difference)] padding = [(diff[0], diff[1]) if pad else (0, 0) for diff, pad in zip(diffs, to_pad)] cropping = [slice(diff[0], -diff[1]) if crop else slice(None, None) for diff, crop in zip(diffs, to_crop)] volume = np.pad(volume, pad_width=padding, mode=self.mode, **self.padding_kwargs) volume = volume[cropping] return volume ================================================ FILE: inferno/io/volumetric/__init__.py ================================================ from .volume import VolumeLoader, HDF5VolumeLoader, TIFVolumeLoader from .lazy_volume_loader import LazyHDF5VolumeLoader, LazyZarrVolumeLoader, LazyN5VolumeLoader ================================================ FILE: inferno/io/volumetric/lazy_volume_loader.py ================================================ import numpy as np import os import pickle from concurrent import futures # try to load io libraries (h5py and z5py) try: import h5py WITH_H5PY = True except ImportError: WITH_H5PY = False try: import z5py WITH_Z5PY = True except ImportError: WITH_Z5PY = False from ..core.base import SyncableDataset from ..core.base import IndexSpec from . 
# TODO support h5py as well
def filter_base_sequence(input_path, input_key, window_size, stride,
                         filter_function, n_threads):
    """Build a sliding-window sequence over a z5py dataset and drop the
    windows rejected by `filter_function`.

    Parameters
    ----------
    input_path : str
        Path to the container (opened with z5py, i.e. n5/zarr).
    input_key : str
        Key of the dataset inside the container.
    window_size : list or tuple
        Size of the sliding window.
    stride : list or tuple
        Strides of the sliding window.
    filter_function : callable
        Called with the data of each window; windows for which it returns
        a truthy value are REMOVED from the sequence.
    n_threads : int
        Number of threads used to load and check the windows.

    Returns
    -------
    list
        The window slices that passed the filter.
    """
    with z5py.File(input_path, 'r') as f:
        ds = f[input_key]
        shape = list(ds.shape)
        # windows are generated shuffled and with overhanging windows at the
        # volume borders
        sequence = vu.slidingwindowslices(shape=shape, window_size=window_size,
                                          strides=stride, shuffle=True,
                                          add_overhanging=True)

        def check_slice(slice_id, slice_):
            # NOTE(review): prints once per window; consider logging instead
            print("Checking slice_id", slice_id)
            data = ds[slice_]
            # keep the window only if the filter does NOT flag it
            if filter_function(data):
                return None
            else:
                return slice_

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(check_slice, slice_id, slice_)
                     for slice_id, slice_ in enumerate(sequence)]
            filtered_sequence = [t.result() for t in tasks]
    # drop the rejected (None) entries
    filtered_sequence = [seq for seq in filtered_sequence if seq is not None]
    return filtered_sequence
def normalize_slice(self, data_slice):
    """Normalize `data_slice` against the dataset shape.

    Open bounds (`None` start/stop) are replaced by 0 / the dimension size,
    and missing trailing dimensions are filled with full slices. Returns
    None when no slice was given.
    """
    if data_slice is None:
        return None
    shape = self.dataset.shape
    normalized = []
    for sl, sh in zip(data_slice, shape):
        start = 0 if sl.start is None else sl.start
        stop = sh if sl.stop is None else sl.stop
        normalized.append(slice(start, stop))
    # cover any trailing dimensions the given slice did not mention
    for sh in shape[len(normalized):]:
        normalized.append(slice(0, sh))
    return tuple(normalized)
def __getitem__(self, index):
    """Load the window at `index` from the lazy dataset.

    The slices in `self.base_sequence` live in the padded coordinate space
    (cf. `get_shape`); they are shifted back to dataset coordinates, clamped
    to the dataset extent, and the missing border region is re-created with
    `np.pad`. Optional `data_slice` offsets are applied afterwards.

    Returns
    -------
    The transformed window, plus an `IndexSpec` when
    `self.return_index_spec` is set.
    """
    # Casting to int would allow index to be IndexSpec objects.
    index = int(index)
    slices = self.base_sequence[index]
    slices_ = tuple(slices)
    need_padding = False
    pad_width = None
    if self.padding is not None:
        # get the start and stop positions in the dataset without padding
        starts = [sl.start - pad[0] for sl, pad in zip(slices_, self.padding)]
        stops = [sl.stop - pad[0] for sl, pad in zip(slices_, self.padding)]
        # check if we need to pad to the left
        pad_left = None
        if any(start < 0 for start in starts):
            pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)
            starts = [max(0, start) for start in starts]
        # check if we need to pad to the right
        pad_right = None
        if any(stop > sh for stop, sh in zip(stops, self._data_shape)):
            pad_right = tuple(stop - sh if stop > sh else 0
                              for stop, sh in zip(stops, self._data_shape))
            stops = [min(sh, stop) for sh, stop in zip(self._data_shape, stops)]
        # BUG FIX: the slicing must be moved to dataset coordinates for
        # EVERY window, not only for those that need border padding;
        # previously interior windows read a region shifted by the left
        # padding (and could overrun the dataset near the upper border)
        slices_ = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        # check if we need any padding and if so calculate the padding width
        need_padding = pad_left is not None or pad_right is not None
        if need_padding:
            pad_left = (0,) * len(self.shape) if pad_left is None else pad_left
            pad_right = (0,) * len(self.shape) if pad_right is None else pad_right
            pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))
    # if we have data-slices, we need to bring
    # the slices back to the volume space
    if self.data_slice is not None:
        slices_ = tuple(slice(sl.start + dsl.start, sl.stop + dsl.start)
                        for sl, dsl in zip(slices_, self.data_slice))
    # load the slice and pad if necessary
    sliced_volume = self.dataset[slices_]
    if need_padding:
        sliced_volume = np.pad(sliced_volume, pad_width=pad_width,
                               mode=self.padding_mode)
    if self.transforms is None:
        transformed = sliced_volume
    else:
        transformed = self.transforms(sliced_volume)
    if self.return_index_spec:
        return transformed, IndexSpec(index=index, base_sequence_at_index=slices)
    else:
        return transformed
def clone(self, dataset=None, transforms=None, name=None):
    """Create a copy of this loader, optionally swapping out the dataset,
    the transforms and/or the name.

    Parameters left at None keep the values of this instance. The copy is
    created without running `__init__` (the current `__dict__` is copied).
    """
    # BUG FIX: the shape check used to run unconditionally and crashed with
    # an AttributeError when `dataset` was left at its default of None
    if dataset is not None:
        # Make sure the dataset shapes check out
        assert dataset.shape == self.dataset.shape
    # Make a new instance (without initializing)
    new = type(self).__new__(type(self))
    # Update dictionary to initialize
    new_dict = dict(self.__dict__)
    if dataset is not None:
        new_dict.update({'dataset': dataset})
    if transforms is not None:
        new_dict.update({'transforms': transforms})
    if name is not None:
        new_dict.update({'name': name})
    new.__dict__.update(new_dict)
    return new
# we do not support step in the dataslice
def validate_data_slice(self, data_slice):
    """Assert that no dimension of `data_slice` uses a step other than 1."""
    if data_slice is None:
        return
    for sl in data_slice:
        assert sl.step in (None, 1), "Complicated step is not supported"
assert slicing_config.get('downsampling_ratio', None) is None,\ "Downsampling is not supported by z5py based loaderes" super(LazyZarrVolumeLoader, self).__init__(file_impl=z5py.ZarrFile, path=path, path_in_file=path_in_file, data_slice=data_slice, transforms=transforms, name=name, **slicing_config) ================================================ FILE: inferno/io/volumetric/volume.py ================================================ import numpy as np import os import skimage.io from ..core.base import SyncableDataset from ..core.base import IndexSpec from . import volumetric_utils as vu from ...utils import io_utils as iou from ...utils import python_utils as pyu from ...utils.exceptions import assert_, ShapeError class VolumeLoader(SyncableDataset): """ Loader for in-memory volumetric data. Parameters ---------- volume: np.ndarray the volumetric data window_size: list or tuple size of the (3d) sliding window used for iteration stride: list or tuple stride of the (3d) sliding window used for iteration downsampling_ratio: list or tuple (default: None) factor by which the data is downsampled (no downsapling by default) padding: list (default: None) padding for data, follows np.pad syntax padding_mode: str (default: 'reflect') padding mode as in np.pad transforms: callable (default: None) transforms applied on each batch loaded from volume return_index_spec: bool (default: False) whether to return the index spec for each batch name: str (default: None) name of this volume is_multichannel: bool (default: False) is this a multichannel volume? 
def __init__(self, volume, window_size, stride, downsampling_ratio=None, padding=None,
             padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
             is_multichannel=False):
    """See the class docstring for the parameter description."""
    super(VolumeLoader, self).__init__()
    # Validate volume
    assert isinstance(volume, np.ndarray), str(type(volume))
    # Validate window size and stride
    if is_multichannel:
        # the sliding window covers only the spatial axes
        assert_(len(window_size) + 1 == volume.ndim,
                "%i, %i" % (len(window_size), volume.ndim), ShapeError)
        assert_(len(stride) + 1 == volume.ndim, exception_type=ShapeError)
        # TODO implement downsampling and padding for multi-channel volume
        assert_(downsampling_ratio is None, exception_type=NotImplementedError)
        assert_(padding is None, exception_type=NotImplementedError)
    else:
        assert_(len(window_size) == volume.ndim,
                "%i, %i" % (len(window_size), volume.ndim), ShapeError)
        assert_(len(stride) == volume.ndim, exception_type=ShapeError)
    # Validate transforms
    assert_(transforms is None or callable(transforms))
    self.name = name
    self.return_index_spec = return_index_spec
    self.volume = volume
    self.window_size = window_size
    self.stride = stride
    self.padding_mode = padding_mode
    self.is_multichannel = is_multichannel
    self.transforms = transforms
    # DataloaderIter should do the shuffling
    self.shuffle = False
    # number of spatial dimensions (the channel axis is excluded)
    ndim = self.volume.ndim - 1 if is_multichannel else self.volume.ndim
    if downsampling_ratio is None:
        self.downsampling_ratio = [1] * ndim
    elif isinstance(downsampling_ratio, int):
        # BUG FIX: this used `self.volume.ndim`, which disagrees with the
        # spatial `ndim` for multi-channel volumes (currently unreachable
        # because multichannel forbids downsampling, but kept consistent)
        self.downsampling_ratio = [downsampling_ratio] * ndim
    elif isinstance(downsampling_ratio, (list, tuple)):
        # BUG FIX: same `ndim` consistency fix as above
        assert_(len(downsampling_ratio) == ndim, exception_type=ShapeError)
        self.downsampling_ratio = list(downsampling_ratio)
    else:
        raise NotImplementedError
    if padding is None:
        # independent [0, 0] pairs (a shared inner list would alias)
        self.padding = [[0, 0] for _ in range(ndim)]
    else:
        self.padding = padding
        self.pad_volume()
    self.base_sequence = self.make_sliding_windows()
def clone(self, volume=None, transforms=None, name=None):
    """Create a copy of this loader, optionally swapping out the volume,
    the transforms and/or the name.

    Parameters left at None keep the values of this instance. The copy is
    created without running `__init__` (the current `__dict__` is copied).
    """
    # BUG FIX: the shape check used to run unconditionally and crashed with
    # an AttributeError when `volume` was left at its default of None
    if volume is not None:
        # Make sure the volume shapes check out
        assert_(volume.shape == self.volume.shape, exception_type=ShapeError)
    # Make a new instance (without initializing)
    new = type(self).__new__(type(self))
    # Update dictionary to initialize
    new_dict = dict(self.__dict__)
    if volume is not None:
        new_dict.update({'volume': volume})
    if transforms is not None:
        new_dict.update({'transforms': transforms})
    if name is not None:
        new_dict.update({'name': name})
    new.__dict__.update(new_dict)
    return new
Zarr and n5 are file formats very similar to hdf5, but use the regular filesystem to store data instead of a filesystem in a file as hdf5. The file type will be infered from the extension: .hdf5, .h5 and .hdf map to hdf5 .n5 maps to n5 .zr and .zarr map to zarr It will fail for other extensions. Parameters ---------- path: str path to file path_in_h5_dataset: str (default: None) path in file data_slice: slice (default: None) slice loaded from dataset transforms: callable (default: None) transforms applied on each batch loaded from volume name: str (default: None) name of this volume slicing_config: kwargs keyword arguments for base class `VolumeLoader` """ @staticmethod def is_h5(file_path): ext = os.path.splitext(file_path)[1].lower() if ext in ('.h5', '.hdf', '.hdf5'): return True elif ext in ('.zarr', '.zr', '.n5'): return False else: raise RuntimeError("Could not infer volume type for file extension %s" % ext) def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=None, name=None, **slicing_config): if isinstance(path, dict): assert name is not None assert name in path self.path = path.get(name) elif isinstance(path, str): assert os.path.exists(path), path self.path = path else: raise NotImplementedError if isinstance(path_in_h5_dataset, dict): assert name is not None assert name in path_in_h5_dataset self.path_in_h5_dataset = path_in_h5_dataset.get(name) elif isinstance(path_in_h5_dataset, str): self.path_in_h5_dataset = path_in_h5_dataset elif path_in_h5_dataset is None: self.path_in_h5_dataset = None else: raise NotImplementedError # get the dataslice if data_slice is None or isinstance(data_slice, (str, list)): self.data_slice = vu.parse_data_slice(data_slice) elif isinstance(data_slice, dict): assert name is not None assert name in data_slice self.data_slice = vu.parse_data_slice(data_slice.get(name)) else: raise NotImplementedError slicing_config_for_name = pyu.get_config_for_name(slicing_config, name) # adapt data-slice if this is 
a multi-channel volume (slice is not applied to channel dimension) if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False): self.data_slice = (slice(None),) + self.data_slice assert 'window_size' in slicing_config_for_name, str(slicing_config_for_name) assert 'stride' in slicing_config_for_name # Read in volume from file (can be hdf5, n5 or zarr) if self.is_h5(self.path): volume = iou.fromh5(self.path, self.path_in_h5_dataset, dataslice=self.data_slice) else: volume = iou.fromz5(self.path, self.path_in_h5_dataset, dataslice=self.data_slice) # Initialize superclass with the volume super(HDF5VolumeLoader, self).__init__(volume=volume, name=name, transforms=transforms, **slicing_config_for_name) class TIFVolumeLoader(VolumeLoader): """Loader for volumes stored in .tif files.""" def __init__(self, path, data_slice=None, transforms=None, name=None, **slicing_config): """ Parameters ---------- path : str Path to the volume. transforms : callable Transforms to apply on the read volume. slicing_config : dict Dictionary specifying the sliding window. Must contain keys 'window_size' and 'stride'. 
""" if isinstance(path, dict): assert name in path.keys() assert os.path.exists(path.get(name)) self.path = path.get(name) elif isinstance(path, str): assert os.path.exists(path) self.path = path else: raise NotImplementedError assert 'window_size' in slicing_config assert 'stride' in slicing_config if data_slice is None or isinstance(data_slice, (str, list)): self.data_slice = vu.parse_data_slice(data_slice) elif isinstance(data_slice, dict): assert name is not None assert name in data_slice self.data_slice = vu.parse_data_slice(data_slice.get(name)) else: raise NotImplementedError # Read in volume from file volume = skimage.io.imread(self.path) # and slice it volume = volume[self.data_slice] if self.data_slice is not None else volume # Initialize superclass with the volume super(TIFVolumeLoader, self).__init__(volume=volume, transforms=transforms, **slicing_config) ================================================ FILE: inferno/io/volumetric/volumetric_utils.py ================================================ import random import itertools as it def slidingwindowslices(shape, window_size, strides, ds=1, shuffle=True, rngseed=None, dataslice=None, add_overhanging=True): # only support lists or tuples for shape, window_size and strides assert isinstance(shape, (list, tuple)) assert isinstance(window_size, (list, tuple)), "%s" % (str(type(window_size))) assert isinstance(strides, (list, tuple)) dim = len(shape) assert len(window_size) == dim assert len(strides) == dim # check for downsampling assert isinstance(ds, (list, tuple, int)) if isinstance(ds, int): ds = [ds] * dim assert len(ds) == dim # Seed RNG if a seed is provided if rngseed is not None: random.seed(rngseed) # sliding windows in one dimenstion def dimension_window(start, stop, wsize, stride, dimsize, ds_dim): starts = range(start, stop + 1, stride) slices = [slice(st, st + wsize, ds_dim) for st in starts if st + wsize <= dimsize] # add an overhanging window at the end if the windoes # do not fit and 
`add_overhanging` if slices[-1].stop != dimsize and add_overhanging: slices.append(slice(dimsize - wsize, dimsize, ds_dim)) if shuffle: random.shuffle(slices) return slices # determine adjusted start and stop coordinates if we have a dataslice # otherwise predict the whole volume if dataslice is not None: assert len(dataslice) == dim, "Dataslice must be a tuple with len = data dimension." starts = [0 if sl.start is None else sl.start for sl in dataslice] stops = [sh - wsize if sl.stop is None else sl.stop - wsize for sl, wsize, sh in zip(dataslice, window_size, shape)] else: starts = dim * [0] stops = [dimsize - wsize if wsize != dimsize else dimsize for dimsize, wsize in zip(shape, window_size)] assert all(stp > strt for strt, stp in zip(starts, stops)),\ "%s, %s" % (str(starts), str(stops)) nslices = [dimension_window(start, stop, wsize, stride, dimsize, ds_dim) for start, stop, wsize, stride, dimsize, ds_dim in zip(starts, stops, window_size, strides, shape, ds)] return it.product(*nslices) # This code is legacy af, don't judge # Define a sliding window iterator (this time, more readable than a wannabe one-liner) def slidingwindowslices_depr(shape, nhoodsize, stride=1, ds=1, window=None, ignoreborder=True, shuffle=True, rngseed=None, startmins=None, startmaxs=None, dataslice=None): """ Returns a generator yielding (shuffled) sliding window slice objects. :type shape: int or list of int :param shape: Shape of the input data :type nhoodsize: int or list of int :param nhoodsize: Window size of the sliding window. :type stride: int or list of int :param stride: Stride of the sliding window. :type shuffle: bool :param shuffle: Whether to shuffle the iterator. """ # Determine dimensionality of the data datadim = len(shape) # Parse window if window is None: window = ['x'] * datadim else: assert len(window) == datadim, \ "Window must have the same length as the number of data dimensions." 
# Parse nhoodsize and stride nhoodsize = [nhoodsize, ] * datadim if isinstance(nhoodsize, int) else nhoodsize stride = [stride, ] * datadim if isinstance(stride, int) else stride ds = [ds, ] * datadim if isinstance(ds, int) else ds # Seed RNG if a seed is provided if rngseed is not None: random.seed(rngseed) # Define a function that gets a 1D slice def _1Dwindow(startmin, startmax, nhoodsize, stride, ds, seqsize, shuffle): starts = range(startmin, startmax + 1, stride) if ignoreborder: slices = [slice(st, st + nhoodsize, ds) for st in starts if st + nhoodsize <= seqsize] else: slices = [slice(st, ((st + nhoodsize) if st + nhoodsize <= seqsize else None), ds) for st in starts] if shuffle: random.shuffle(slices) return slices # Get window start limits if dataslice is None: startmins = [0, ] * datadim if startmins is None else startmins startmaxs = [shp - nhoodsiz for shp, nhoodsiz in zip(shape, nhoodsize)] \ if startmaxs is None else startmaxs else: assert len(dataslice) == datadim, \ "Dataslice must be a tuple with len = data dimension." 
startmins = [sl.start for sl in dataslice] startmaxs = [sl.stop - nhoodsiz for sl, nhoodsiz in zip(dataslice, nhoodsize)] def _to_list(x): if not isinstance(x, (list, tuple)): return list(x) else: return x # The final iterator is going to be a cartesian product of the lists in nslices nslices = [_1Dwindow(startmin, startmax, nhoodsiz, st, dsample, datalen, shuffle) if windowspec == 'x' else [slice(ws, ws + 1) for ws in _to_list(windowspec)] for startmin, startmax, datalen, nhoodsiz, st, windowspec, dsample in zip(startmins, startmaxs, shape, nhoodsize, stride, window, ds)] return it.product(*nslices) def parse_data_slice(data_slice): """Parse a dataslice as a list of slice objects.""" if data_slice is None: return data_slice elif isinstance(data_slice, (list, tuple)) and \ all([isinstance(_slice, slice) for _slice in data_slice]): return tuple(data_slice) else: assert isinstance(data_slice, str) # Get rid of whitespace data_slice = data_slice.replace(' ', '') # Split by commas dim_slices = data_slice.split(',') # Build slice objects slices = [] for dim_slice in dim_slices: indices = dim_slice.split(':') if len(indices) == 2: start, stop, step = indices[0], indices[1], None elif len(indices) == 3: start, stop, step = indices else: raise RuntimeError # Convert to ints start = int(start) if start != '' else None stop = int(stop) if stop != '' else None step = int(step) if step is not None and step != '' else None # Build slices slices.append(slice(start, stop, step)) return tuple(slices) ================================================ FILE: inferno/trainers/__init__.py ================================================ from . import basic from . import callbacks from . 
basic import Trainer __all__ = ['basic','callbacks','Trainer'] ================================================ FILE: inferno/trainers/basic.py ================================================ from datetime import datetime from inspect import signature import os import shutil # These are fetched from globals, they're not unused # noinspection PyUnresolvedReferences import dill # noinspection PyUnresolvedReferences import pickle import torch from numpy import inf from torch.utils.data import DataLoader from torch.nn.parallel.data_parallel import data_parallel from .callbacks.logging.base import Logger from .callbacks.logging import get_logger from ..utils import train_utils as tu from ..utils import python_utils as pyu from ..utils import torch_utils as thu from ..extensions import metrics from ..extensions import optimizers from ..extensions import criteria from .callbacks import CallbackEngine from .callbacks import Console from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError # NOTE for distributed training, we might also need # from apex.parallel import DistributedDataParallel as DDP # but I don't know where exactly to put it. try: from apex import amp except ImportError: amp = None class Trainer(object): """A basic trainer. Given a torch model, this class encapsulates the training and validation loops, checkpoint creation, logging, CPU <-> GPU transfers and managing data-loaders. In addition, this class interacts with the callback engine (found at `inferno.trainers.callbacks.base.CallbackEngine`), which manages callbacks at certain preset events. Notes ----- Logging is implemented as a special callback, in the sense that it's jointly managed by the this class and the callback engine. This is primarily because general callbacks are not intended to be serializable, but not being able to serialize the logger is a nuisance. """ def __init__(self, model=None): """ Parameters ---------- model : torch.nn.Module Torch model to bind to. 
""" # Privates # Core self._model = None self._optimizer = None self._criterion = None self._retain_graph = False self._backprop_every = 1 # Metric evaluation self._metric = None self._evaluate_metric_every = None self._metric_evaluation_externally_triggered = False self._last_metric_evaluated_at_epoch = 0 # Logging self._logger = None self._last_logged = {} self._log_directory = {} # Data logistics self._loaders = {} self._loader_iters = {} self._loader_specs = {} # Iteration and epoch book-keeping self._iteration_count = 0 self._epoch_count = 0 self._batch_count = 0 self._current_mode = 'train' # GPU and dtype business self._use_cuda = False self._dtype = 'float' self._devices = None self._base_device_ordinal = None # Validation self._save_at_best_validation_score = False self._best_validation_score = None self._is_iteration_with_best_validation_score = False self._validate_every = None self._num_validation_iterations = None self._target_batch_dim = 0 self._validation_criterion = None # We should exclude the zero-th epoch from validation self._last_validated_at_epoch = 0 self._last_validated_at_iteration = 0 # This is to allow a callback to trigger a validation by setting # trainer.validate_now = True self._validation_externally_triggered = False # Checkpointing self._save_every = None self._save_to_directory = None self._pickle_module = 'pickle' # Defaults for file names self._checkpoint_filename = 'checkpoint.pytorch' self._best_checkpoint_filename = 'best_checkpoint.pytorch' # Nothing to save at epoch 0 self._last_saved_at_epoch = 0 # This is to allow a callback to trigger a save by setting trainer.save_now = True self._save_externally_triggered = False # Stopping conditions self._max_num_iterations = None self._max_num_epochs = None # Callbacks and states self._callback_engine = CallbackEngine().bind_trainer(self) self._state = {} # Print console self._console = Console() # Train with mixed precision, only works # if we have apex self._mixed_precision = False 
self._apex_opt_level = 'O1' # Public if model is not None: self.model = model @property def mixed_precision(self): return self._mixed_precision # this needs to be called after model and optimizer are set @mixed_precision.setter def mixed_precision(self, mp): if mp: assert_(amp is not None, "Cannot use mixed precision training without apex library", RuntimeError) assert_(self.model is not None and self._optimizer is not None, "Model and optimizer need to be set before activating mixed precision", RuntimeError) # in order to support BCE loss amp.register_float_function(torch, 'sigmoid') # For now, we don't allow to set 'keep_batchnorm' and 'loss_scale' self.model, self._optimizer = amp.initialize(self.model, self._optimizer, opt_level=self._apex_opt_level, keep_batchnorm_fp32=None) self._mixed_precision = mp @property def apex_opt_level(self): return self._apex_opt_level @apex_opt_level.setter def apex_opt_level(self, opt_level): assert_(opt_level in ('O0', 'O1', 'O2', 'O3'), "Invalid optimization level", ValueError) self._apex_opt_level = opt_level @property def console(self): """Get the current console.""" return self._console def set_console(self, console): assert_(isinstance(console, Console), "`console` must be a Console object.", TypeError) self._console = console return self def quiet(self): self.console.toggle_progress(False) return self @property def callbacks(self): """Gets the callback engine.""" return self._callback_engine def register_callback(self, callback, trigger='auto', **callback_kwargs): """ Registers a callback with the internal callback engine. Parameters ---------- callback : type or callable Callback to register. trigger : str Specify the event that triggers the callback. Leave at 'auto' to have the callback-engine figure out the triggers. See `inferno.training.callbacks.base.CallbackEngine` documentation for more on this. callback_kwargs : dict If `callback` is a type, initialize an instance with these keywords to the __init__ method. 
Returns ------- Trainer self. """ if isinstance(callback, type): callback = callback(**callback_kwargs) self._callback_engine.register_callback(callback, trigger=trigger) return self @property def model(self): """Gets the model.""" assert_(self._model is not None, "Model is not defined yet.", NotSetError) return self._model @model.setter def model(self, value): self.bind_model(value) def bind_model(self, model): """ Binds a model to the trainer. Equivalent to setting model. Parameters ---------- model : torch.nn.Module Model to bind. Returns ------- Trainer self. """ assert_(isinstance(model, torch.nn.Module), "Model must be a torch.nn.Module.", NotTorchModuleError) self._model = model # Transfer model to GPU if required if self._use_cuda: self._model.cuda() return self @property def model_is_defined(self): return self._model is not None @property def retain_graph(self): return self._retain_graph @retain_graph.setter def retain_graph(self, value): assert isinstance(value, bool) self._retain_graph = value @property def backprop_every(self): return self._backprop_every @backprop_every.setter def backprop_every(self, value): self.set_backprop_every(value) def set_backprop_every(self, num_steps): """ Set frequency of backpropagation. To use in cases of small batch sizes. 
Parameters ---------- num_steps : number of steps (iterations/batches) to backprop after Returns ------- Trainer self """ assert isinstance(num_steps, int) self._backprop_every = num_steps return self @property def optimizer(self): """Gets the optimizer.""" assert_(self._optimizer is not None, "Optimizer is not set yet.", NotSetError) return self._optimizer @optimizer.setter def optimizer(self, value): if isinstance(value, str) or callable(value): self.build_optimizer(value) elif isinstance(value, dict): self.build_optimizer(**value) else: raise NotImplementedError @property def optimizer_is_defined(self): return self._optimizer is not None def build_optimizer(self, method, param_groups=None, **kwargs): """ Builds the optimizer for training. Parameters ---------- method : str or callable or torch.optim.Optimizer Name of the optimizer when str, handle to the optimizer class when callable, or a torch.optim.Optimizer instance. If a name is provided, this method looks for the optimizer in `torch.optim` module first and in inferno.extensions.optimizers second. param_groups : list of dict Specifies the parameter group. Defaults to model.parameters() if None. kwargs : dict Keyword arguments to the optimizer. Returns ------- Trainer self. Raises ------ AssertionError if optimizer is not found NotImplementedError if method is not str or callable. 
""" if isinstance(method, str): optimizer_class = getattr(torch.optim, method, None) if optimizer_class is None: # Look for optimizer in extensions optimizer_class = getattr(optimizers, method, None) assert optimizer_class is not None, "Optimizer {} not found.".format(method) elif callable(method) and isinstance(method, type): optimizer_class = method elif isinstance(method, torch.optim.Optimizer): self._optimizer = method return self else: raise NotImplementedError param_groups = self.model.parameters() if param_groups is None else param_groups self._optimizer = optimizer_class(param_groups, **kwargs) return self @property def criterion(self): """Gets the loss criterion.""" assert_(self._criterion is not None, "Criterion is not set yet.", NotSetError) return self._criterion @criterion.setter def criterion(self, value): if isinstance(value, str) or callable(value): self.build_criterion(value) elif isinstance(value, dict): self.build_criterion(**value) else: raise RuntimeError(f"Criterion can either be set to a string, callable or a dict. " f"Got {type(value).__name__} instead.") def build_criterion(self, method, **kwargs): """ Builds the loss criterion for training. Parameters ---------- method : str or callable or torch.nn.Module Name of the criterion when str, criterion class when callable, or a torch.nn.Module instance. If a name is provided, this method looks for the criterion in `torch.nn`. kwargs : dict Keyword arguments to the criterion class' constructor if applicable. Returns ------- Trainer self. Raises ------ AssertionError if criterion is not found. NotImplementedError if method is neither a str nor a callable. 
""" if isinstance(method, str): # Look for criteria in torch criterion_class = getattr(torch.nn, method, None) if criterion_class is None: # Look for it in extensions criterion_class = getattr(criteria, method, None) assert criterion_class is not None, "Criterion {} not found.".format(method) elif callable(method) and isinstance(method, type): criterion_class = method elif isinstance(method, torch.nn.Module): self._criterion = method return self else: raise NotImplementedError self._criterion = criterion_class(**kwargs) # Transfer criterion to GPU if required. This is necessary for e.g. weighted loss, # where the weight is registered as a buffer. # The criterion is to be cuda'ed only if the model is on CUDA (self._use_cuda) and # the base_device is not CPU (ordinal -1). if hasattr(self, '_base_device_ordinal'): # This is to not break old checkpoints base_device_ordinal = self._base_device_ordinal else: base_device_ordinal = None if self._use_cuda and base_device_ordinal != 1: self._criterion.cuda() return self @property def criterion_is_defined(self): return self._criterion is not None @property def validation_criterion(self): if self._validation_criterion is None: return self.criterion else: return self._validation_criterion @validation_criterion.setter def validation_criterion(self, value): if isinstance(value, str) or callable(value): self.build_validation_criterion(value) elif isinstance(value, dict): self.build_validation_criterion(**value) else: raise RuntimeError(f"Validation criterion can either be set to a string, callable " f"or a dict. Got {type(value).__name__} instead.") def build_validation_criterion(self, method, **kwargs): """ Builds the loss criterion for validation. Parameters ---------- method : str or callable or torch.nn.Module Name of the criterion when str, criterion class when callable, or a torch.nn.Module instance. If a name is provided, this method looks for the criterion in `torch.nn`. 
kwargs : dict Keyword arguments to the criterion class' constructor if applicable. Returns ------- Trainer self. Raises ------ AssertionError if criterion is not found. NotImplementedError if method is neither a str nor a callable. """ if isinstance(method, str): # Look for criteria in torch criterion_class = getattr(torch.nn, method, None) if criterion_class is None: # Look for it in extensions criterion_class = getattr(criteria, method, None) assert criterion_class is not None, "Criterion {} not found.".format(method) elif callable(method) and isinstance(method, type): criterion_class = method elif isinstance(method, torch.nn.Module): self._validation_criterion = method return self else: raise NotImplementedError self._validation_criterion = criterion_class(**kwargs) # Transfer criterion to GPU if required. This is necessary for e.g. weighted loss, # where the weight is registered as a buffer. # The criterion is to be cuda'ed only if the model is on CUDA (self._use_cuda) and # the base_device is not CPU (ordinal -1). 
if hasattr(self, '_base_device_ordinal'): # This is to not break old checkpoints base_device_ordinal = self._base_device_ordinal else: base_device_ordinal = None if self._use_cuda and base_device_ordinal != 1: self._validation_criterion.cuda() return self def validation_criterion_is_train_criterion(self, yes=True): if yes: # This will cause the property to return train criterion self._validation_criterion = None return self @property def validation_criterion_is_defined(self): return self._validation_criterion is not None @property def metric(self): """Gets the evaluation metric.""" assert_(self._metric is not None, "Metric is not set yet.", NotSetError) return self._metric @metric.setter def metric(self, value): if callable(value) or isinstance(value, str): self.build_metric(value) else: raise NotImplementedError @property def evaluating_metric_every(self): return self._evaluate_metric_every def evaluate_metric_every(self, frequency): """ Set frequency of metric evaluation __during training__ (and not during validation). Parameters ---------- frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int Metric evaluation frequency. If str, it could be (say) '10 iterations' or '1 epoch'. If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). If int (say 10), it's interpreted as (10, 'iterations'). 
Returns ------- Trainer self """ self._evaluate_metric_every = tu.Frequency.build_from(frequency, priority='iterations') assert self._evaluate_metric_every.is_consistent return self @property def evaluate_metric_now(self): if self._metric_evaluation_externally_triggered: # Reset trigger self._metric_evaluation_externally_triggered = False return True elif self._evaluate_metric_every is None: # By default, evaluate metric every time return True elif self._evaluate_metric_every is not None and self._evaluate_metric_every.by_epoch: # Don't evaluate if we've done so already this epoch if self._last_metric_evaluated_at_epoch == self._epoch_count: return False else: # If we haven't evaluated this epoch, check if we should return self._evaluate_metric_every.match(epoch_count=self._epoch_count) else: # This is reached when evaluate_metric_every is defined and matching by # iteration count return self._evaluate_metric_every.match(iteration_count=self._iteration_count) @evaluate_metric_now.setter def evaluate_metric_now(self, value): self._metric_evaluation_externally_triggered = bool(value) def build_metric(self, method, **kwargs): """ Builds the metric for evaluation. Parameters ---------- method : callable or str Name of the metric when string, metric class or a callable object when callable. If a name is provided, this method looks for the metric in `inferno.extensions.metrics`. kwargs : dict Keyword arguments to the metric class' constructor, if applicable. Returns ------- Trainer self. Raises ------ AssertionError: if the metric is not found. 
""" if callable(method): if isinstance(method, type): self._metric = method(**kwargs) else: self._metric = method elif isinstance(method, str): assert hasattr(metrics, method), \ "Could not find the metric '{}'.".format(method) self._metric = getattr(metrics, method)(**kwargs) else: raise NotImplementedError return self @property def metric_is_defined(self): """Checks if the metric is defined.""" return self._metric is not None def eval_mode(self): """Set model, criterion and metric to eval mode""" self._current_mode = 'eval' self.model.eval() if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module): self.criterion.eval() if self.metric_is_defined and isinstance(self.metric, torch.nn.Module): self.metric.eval() return self def train_mode(self): """Set model, criterion and metric to train mode""" self._current_mode = 'train' self.model.train() if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module): self.criterion.train() if self.metric_is_defined and isinstance(self.metric, torch.nn.Module): self.metric.train() return self @property def train_loader(self): assert self._loaders.get('train') is not None return self._loaders.get('train') @train_loader.setter def train_loader(self, value): assert isinstance(value, DataLoader) self._loaders.update({'train': value}) @property def validate_loader(self): assert self._loaders.get('validate') is not None return self._loaders.get('validate') @validate_loader.setter def validate_loader(self, value): assert isinstance(value, DataLoader) self._loaders.update({'validate': value}) @property def logger(self): """Gets the logger.""" return self._logger @logger.setter def logger(self, value): if isinstance(value, dict): self.build_logger(**value) else: self.build_logger(logger=value) @property def log_directory(self): """Gets the log directory.""" return self._log_directory @log_directory.setter def log_directory(self, value): """Sets the log directory,""" if value is not None: 
self.set_log_directory(value) @property def pickle_module(self): module_ = globals().get(self._pickle_module, None) assert_(module_ is not None, "Pickle module not found!", ModuleNotFoundError) return module_ _ALLOWED_PICKLE_MODULES = {'pickle', 'dill'} @pickle_module.setter def pickle_module(self, value): assert_(isinstance(value, str), "`pickle_module` must be set to a string.", TypeError) assert_(value in self._ALLOWED_PICKLE_MODULES, f"Pickle module must be one of {self._ALLOWED_PICKLE_MODULES}, " f"got {value} instead.", ValueError) self._pickle_module = value @property def saving_every(self): """Gets the frequency at which checkpoints are made.""" return self._save_every def save_at_best_validation_score(self, yes=True): """Sets whether to save when the validation score is the best seen.""" self._save_at_best_validation_score = yes return self @property def save_now(self): if self._save_externally_triggered: # Reset trigger self._save_externally_triggered = False # Save if externally triggered return True elif self._save_at_best_validation_score and self._is_iteration_with_best_validation_score: return True else: # Check if we're saving by epoch if self._save_every is not None and self._save_every.by_epoch: # Don't save if we've already saved once this epoch if self._epoch_count == self._last_saved_at_epoch: return False else: # If we haven't saved this epoch, check if we should return self._save_every.match(epoch_count=self._epoch_count) else: # We're saving by iterations return self._save_every is not None and \ self._save_every.match(iteration_count=self._iteration_count) @save_now.setter def save_now(self, value): """Can be set to true to trigger a checkpoint creation..""" self._save_externally_triggered = bool(value) def save_every(self, frequency, to_directory=None, checkpoint_filename=None, best_checkpoint_filename=None): """ Set checkpoint creation frequency. 
Parameters ---------- frequency : inferno.utils.train_utils.Frequency or tuple or str Checkpoint creation frequency. Examples: '100 iterations' or '1 epochs'. to_directory : str Directory where the checkpoints are to be created. checkpoint_filename : str Name of the checkpoint file. best_checkpoint_filename : str Name of the best checkpoint file. Returns ------- Trainer self. """ self._save_every = tu.Frequency.build_from(frequency, priority='iterations') assert self._save_every.is_consistent self.save_to_directory(to_directory, checkpoint_filename, best_checkpoint_filename) return self @property def save_directory(self): return self._save_to_directory def save_to_directory(self, to_directory=None, checkpoint_filename=None, best_checkpoint_filename=None): if to_directory is not None: assert_(isinstance(to_directory, str), exception_type=TypeError) if not os.path.exists(to_directory): os.makedirs(to_directory) else: assert os.path.isdir(to_directory) self._save_to_directory = to_directory if checkpoint_filename is not None: assert_(isinstance(checkpoint_filename, str), exception_type=TypeError) self._checkpoint_filename = checkpoint_filename if best_checkpoint_filename is not None: assert_(isinstance(best_checkpoint_filename, str), exception_type=TypeError) self._best_checkpoint_filename = best_checkpoint_filename return self @property def validating_every(self): return self._validate_every @property def validate_now(self): if self._validation_externally_triggered: # Reset trigger self._validation_externally_triggered = False return True elif self._validate_every is not None and self._validate_every.by_epoch: # Don't validate if we've done so already this epoch if self._last_validated_at_epoch == self._epoch_count: return False else: # If we haven't validated this epoch, check if we should return self._validate_every.match(epoch_count=self._epoch_count, match_zero=False) else: # Don't validate if we've done once already this iteration if 
self._last_validated_at_iteration == self._iteration_count: return False else: # If we haven't validated this iteration, check if we should. The `match_zero` is # redundant, but we'll leave it on anyway. return self._validate_every is not None and \ self._validate_every.match(iteration_count=self._iteration_count, match_zero=False) @validate_now.setter def validate_now(self, value): self._validation_externally_triggered = bool(value) def validate_every(self, frequency, for_num_iterations=None): """ Set validation frequency. Parameters ---------- frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int Validation frequency. If str, it could be (say) '10 iterations' or '1 epoch'. If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). If int (say 10), it's interpreted as (10, 'iterations'). for_num_iterations : int Number of iterations to validate for. If not set, the model is validated on the entire dataset (i.e. till the data loader is exhausted). Returns ------- Trainer self """ self._validate_every = tu.Frequency.build_from(frequency, priority='iterations') assert self._validate_every.is_consistent self._num_validation_iterations = for_num_iterations return self @property def iteration_count(self): return self._iteration_count @property def epoch_count(self): return self._epoch_count @property def target_batch_dim(self): return self._target_batch_dim @target_batch_dim.setter def target_batch_dim(self, value): assert_(value in [0, 1], "target_batch_dim must be either 0 or 1, got {value} instead.".format(value=value), ValueError) self._target_batch_dim = value def set_target_batch_dim(self, value): self.target_batch_dim = value return self def build_logger(self, logger=None, log_directory=None, **kwargs): """ Build the logger. Parameters ---------- logger : inferno.trainers.callbacks.logging.base.Logger or str or type Must either be a Logger object or the name of a logger or the class of a logger. 
log_directory : str Path to the directory where the log files are to be stored. kwargs : dict Keyword arguments to the logger class. Returns ------- Trainer self """ if isinstance(logger, Logger): # Set logger and register with the callback engine. self._logger = logger self.callbacks.register_callback(self._logger) elif callable(logger): kwargs.update({'log_directory': log_directory}) self._logger = logger(**kwargs) self.callbacks.register_callback(self._logger) elif isinstance(logger, str): self._logger = get_logger(logger)(**kwargs) self.callbacks.register_callback(self._logger) elif logger is None: pass else: raise NotImplementedError if log_directory is not None: self.set_log_directory(log_directory) return self def set_log_directory(self, log_directory): """ Set the directory where the log files are to be stored. Parameters ---------- log_directory : str Directory where the log files are to be stored. Returns ------- Trainer self """ self._log_directory = log_directory if self._logger is not None: self._logger.set_log_directory(log_directory) return self # States that are fetched dynamically from the trainer object via properties are # dynamic states. Such states can not be updated. 
# The following dictionary maps state keys to the corresponding trainer attribute DYNAMIC_STATES = {'learning_rate': 'current_learning_rate'} def update_state(self, key, value): assert key not in self.DYNAMIC_STATES, \ "State at key '{}' cannot be updated because it's dynamic.".format(key) self._state.update({key: value}) return self def update_state_from_dictionary(self, dictionary): # Unwrap variables (or tensors) self._state.update({ state_key: thu.unwrap(state) for state_key, state in dictionary.items()}) def update_state_from_model_state_hooks(self): if hasattr(self.model, '_state_hooks'): state_hooks = getattr(self.model, '_state_hooks') if isinstance(state_hooks, dict): self.update_state_from_dictionary(state_hooks) def get_state(self, key, default=None): if key in self.DYNAMIC_STATES: return getattr(self, self.DYNAMIC_STATES.get(key), default) else: return self._state.get(key, default) @property def current_learning_rate(self): return self.get_current_learning_rate() def get_current_learning_rate(self): """ Gets the current learning rate. Returns ------- list or float List of learning rates if there are multiple parameter groups, or a float if there's just one. """ learning_rate = [param_group.get('lr', -1.) for param_group in self.optimizer.param_groups] learning_rate = [_learning_rate[0] if thu.is_tensor(_learning_rate) else _learning_rate for _learning_rate in learning_rate] return pyu.from_iterable(learning_rate) def to(self, device): """ Send trainer to device ---------- device : string or torch.device Target device where trainer/model should be send to """ if device == 'cuda': return self.cuda() elif device == 'cpu': return self.cpu() elif isinstance(device, torch.torch.device): self.to(device.type) else: raise NotImplementedError("Can not send trainer to device", device) def cuda(self, devices=None, base_device=None): """ Train on the GPU. Parameters ---------- devices : list Specify the ordinals of the devices to use for dataparallel training. 
base_device : {'cpu', 'cuda'} When using data-parallel training, specify where the result tensors are collected. If 'cuda', the results are collected in `devices[0]`. Returns ------- Trainer self """ # Validate base_device assert_(base_device in [None, 'cpu', 'cuda'], "`base_device` must either be 'cpu' or 'cuda', got {} instead." .format(base_device), DeviceError) if isinstance(devices, int) or (isinstance(devices, (list, tuple)) and len(devices) == 1): # No data-parallelism, make sure base_device is not CPU assert_(base_device != 'cpu', "Without dataparallelism, `base_device` cannot be 'cpu'.", DeviceError) self._base_device_ordinal = {None: None, 'cpu': -1, 'cuda': None}.get(base_device) # Move model to CUDA if self.model_is_defined: self.model.cuda() # Move criterion to cuda if base device ordinal is not -1 (i.e. CPU) # (the criterion is evaluated on the base device) if self.criterion_is_defined and self._base_device_ordinal != -1: self.criterion.cuda() elif self.criterion_is_defined and self._base_device_ordinal == -1: # Criterion is evaluated on the CPU, make sure that's where it lives self.criterion.cpu() self._use_cuda = True self._devices = devices return self def cpu(self): """ Train on the CPU. 
Returns ------- Trainer self """ if self.model_is_defined: self.model.cpu() if self.criterion_is_defined: self.criterion.cpu() self._use_cuda = False self._devices = None return self def is_cuda(self): """Returns whether using GPU for training.""" return self._use_cuda def to_device(self, objects): if isinstance(objects, (list, tuple)): return type(objects)([self.to_device(_object) for _object in objects]) else: return objects.cuda() if self._use_cuda else objects def apply_model(self, *inputs): if hasattr(self, '_base_device_ordinal'): # This is to not break old checkpoints base_device_ordinal = self._base_device_ordinal else: base_device_ordinal = None if self._devices is not None: return data_parallel(self.model, inputs, list(self._devices), output_device=base_device_ordinal) else: return self.model(*inputs) def cast(self, objects): if isinstance(objects, (list, tuple)): return type(objects)([self.cast(_object) for _object in objects]) else: # Cast only the float types, while leaving the ints alone if objects.__class__.__name__ in ['HalfTensor', 'FloatTensor', 'DoubleTensor']: cast_fn = getattr(objects, self._dtype, None) else: cast_fn = None if cast_fn is not None: return cast_fn() else: return objects def set_precision(self, dtype): """ Set training precision. Parameters ---------- dtype : {'double', 'float', 'half'} Training precision. Returns ------- Trainer self """ assert dtype in ['double', 'float', 'half'] self._dtype = dtype self._model = getattr(self._model, dtype)() return self @property def dtype(self): return self._dtype @dtype.setter def dtype(self, value): self.set_precision(value) def bind_loader(self, name, loader, num_inputs=None, num_targets=1): """ Bind a data loader to the trainer. Parameters ---------- name : {'train', 'validate', 'test'} Name of the loader, i.e. what it should be used for. loader : torch.utils.data.DataLoader DataLoader object. num_inputs : int Number of input tensors from the `loader`. 
def get_loader_specs(self, name):
    """Look up the {'num_inputs': ..., 'num_targets': ...} spec that
    `bind_loader` recorded for the loader called `name`."""
    known_loaders = set(self._loader_specs.keys())
    assert name in self._loader_specs.keys(), \
        "Could not find specs about loader '{}'. Valid loader names are: {}".format(
            name, known_loaders)
    return self._loader_specs.get(name)
def fetch_next_batch(self, from_loader='train', restart_exhausted_generators=True,
                     update_batch_count=True, update_epoch_count_if_generator_exhausted=True):
    """
    Fetch the next batch from the loader named `from_loader`.

    Builds the loader's iterator on first use. When the iterator is exhausted it
    is (optionally) rebuilt, the epoch counter is (optionally) bumped via
    `next_epoch`, and the fetch is retried exactly once.
    """
    # Check if the iterator is built
    if from_loader not in self._loader_iters:
        self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()})
    # Try to fetch from iterator
    try:
        # Fetch
        next_batch = next(self._loader_iters[from_loader])
        # Verify
        self.verify_batch(next_batch, from_loader)
        if update_batch_count:
            self._batch_count += 1
        return next_batch
    except StopIteration:
        # This if clause prevents infinite recursion if the loader is empty
        if restart_exhausted_generators:
            self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()})
            # Update epoch count
            if update_epoch_count_if_generator_exhausted:
                self.next_epoch()
            # The retry passes restart_exhausted_generators=False, so an empty
            # loader re-raises StopIteration instead of recursing forever.
            return self.fetch_next_batch(from_loader, restart_exhausted_generators=False,
                                         update_batch_count=update_batch_count)
        else:
            raise

def verify_batch(self, batch, from_loader):
    """
    Sanity-check that `batch` has a tensor count consistent with the loader's
    declared `num_inputs`/`num_targets` spec. Returns the batch unchanged.
    """
    loader_specs = self.get_loader_specs(from_loader)
    num_inputs = loader_specs.get('num_inputs')
    num_targets = loader_specs.get('num_targets')
    if None not in [num_inputs, num_targets]:
        assert len(batch) == num_inputs + num_targets, \
            "Was expecting a batch with {} (= num_inputs) + {} (= num_targets) tensors, " \
            "got one with {} tensors.".format(num_inputs, num_targets, len(batch))
    if num_inputs is not None:
        assert len(batch) > num_inputs, \
            "Expecting {} inputs, but the batch contains only {} tensors." \
                .format(num_inputs, len(batch))
    if num_targets is not None:
        # NOTE(review): the message says "outputs" but the count checked is
        # num_targets; also the strict `>` rejects batches with zero inputs —
        # presumably intentional, TODO confirm.
        assert len(batch) > num_targets, \
            "Expecting {} outputs, but the batch contains only {} tensors." \
                .format(num_targets, len(batch))
    return batch
def split_batch(self, batch, from_loader):
    """
    Split `batch` into (inputs, targets) using the loader's recorded
    num_inputs/num_targets spec; at least one of the two counts must be known.
    A singleton target list is unwrapped to a bare tensor via pyu.from_iterable.
    """
    loader_specs = self.get_loader_specs(from_loader)
    num_inputs = loader_specs.get('num_inputs')
    num_targets = loader_specs.get('num_targets')
    assert not (num_targets is None and num_inputs is None), \
        "Can not split batch if both the number of inputs and targets is not known."
    if num_inputs is None:
        # Unknown number of inputs
        num_inputs = len(batch) - num_targets  # to allow for num_targets == 0
        inputs, targets = batch[:num_inputs], batch[num_inputs:]
    elif num_targets is None:
        # Unknown number of targets
        inputs, targets = batch[:num_inputs], batch[num_inputs:]
    else:
        # Known number of inputs and targets
        # NOTE(review): with num_targets == 0 this slice is batch[-0:], i.e. the
        # whole batch — verify callers never hit this branch with zero targets.
        inputs, targets = batch[:num_inputs], batch[-num_targets:]
    return inputs, pyu.from_iterable(targets)

def restart_generators(self, of_loader=None):
    """Rebuild the iterator of the named loader, or of all loaders if None."""
    if of_loader is None:
        of_loader = self._loaders.keys()
    else:
        assert of_loader in self._loaders.keys(), \
            "Key {} not in loaders ({})".format(of_loader, list(self._loaders))
        of_loader = pyu.to_iterable(of_loader)
    self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()
                               for from_loader in of_loader})
    return self

def wrap_batch(self, batch, from_loader=None, requires_grad=False):
    """
    Prepare a raw batch for computation: move it to the right device, cast it to
    the training precision, and optionally flag its tensors to require gradients.

    Parameters
    ----------
    batch : list or tuple
        Sequence of tensors (or lists of tensors) as produced by the loader.
    from_loader : str
        Loader name; required when the data-parallel base device is the CPU
        (base_device_ordinal == -1), because then only the inputs (not the
        targets) are moved to the GPU.
    requires_grad : bool
        Whether to call `requires_grad_()` on the wrapped tensors.
    """
    base_device_ordinal = \
        self._base_device_ordinal if hasattr(self, '_base_device_ordinal') else None
    # First, send to the right device
    if base_device_ordinal is None:
        # Both inputs and labels are sent to the device
        batch = self.to_device(batch)
    elif base_device_ordinal == -1:
        # Input batches go to device, while labels remain on the CPU.
        # To start, we need the number of input batches, i.e. from_loader must not be None
        assert_(from_loader is not None,
                "`from_loader` needs to be specified if base_device_ordinal is -1 "
                "(i.e. base device for data-parallel training is CPU).",
                ValueError)
        loader_spec = self._loader_specs.get(from_loader)
        assert_(loader_spec is not None,
                "No `loader_spec` found for loader key '{}'.".format(from_loader),
                RuntimeError)
        num_inputs = loader_spec['num_inputs']
        if num_inputs is None:
            num_inputs = len(batch) - loader_spec['num_targets']
        # Fetch input batches and send'em to device (leave the targets alone)
        inputs = batch[:num_inputs]
        inputs = self.to_device(inputs)
        # Finally, build the batch
        batch = inputs + batch[num_inputs:]
    else:
        raise ValueError("Internal Error: Invalid base_device_ordinal: {}."
                         .format(base_device_ordinal))
    # Cast to the right dtype and return
    batch = self.cast(batch)
    # Set gradients if required
    variable_batch = []
    for batch_num, _batch in enumerate(batch):
        if thu.is_tensor(_batch):
            variable_batch.append(_batch.requires_grad_() if requires_grad else _batch)
        elif pyu.is_listlike(_batch):
            variable_batch.append([__batch.requires_grad_() if requires_grad else __batch
                                   for __batch in _batch])
        else:
            raise RuntimeError(f"Was Expecting batch at index {batch_num} to be either a "
                               f"tensor or a list of tensors. Got {type(_batch)} instead.")
    # Preserve the original container type (list vs tuple).
    batch = type(batch)(variable_batch)
    return batch
def next_iteration(self):
    # Bump the global iteration counter.
    self._iteration_count += 1

def next_epoch(self):
    """Close the current epoch: fire END_OF_EPOCH, bump the epoch counter,
    reset the batch counter and fire BEGIN_OF_EPOCH for the new epoch."""
    # Callback before the end of epoch
    self.callbacks.call(self.callbacks.END_OF_EPOCH,
                        epoch_count=self._epoch_count,
                        batch_count=self._batch_count,
                        iteration_count=self._iteration_count)
    self._epoch_count += 1
    self._batch_count = 0
    # Callback after the start of epoch
    self.callbacks.call(self.callbacks.BEGIN_OF_EPOCH,
                        epoch_count=self._epoch_count,
                        batch_count=self._batch_count,
                        iteration_count=self._iteration_count)

def stop_fitting(self, max_num_iterations=None, max_num_epochs=None):
    """Decide whether training should stop. The iteration budget takes priority:
    the epoch budget is only consulted when `max_num_epochs` is given and
    `max_num_iterations` is not."""
    # First priority to iteration count
    if max_num_iterations is not None or max_num_epochs is None:
        # Fall back to the trainer-level setting when no override was passed.
        max_num_iterations = \
            self._max_num_iterations if max_num_iterations is None else max_num_iterations
        assert_(max_num_iterations is not None,
                "Neither max_num_iterations nor max_num_epochs was set.",
                RuntimeError)
        return self._iteration_count >= max_num_iterations
    else:
        # max_num_epochs is specified. It could be 'auto', in which case we read from the
        # class attribute
        max_num_epochs = self._max_num_epochs \
            if isinstance(max_num_epochs, str) and max_num_epochs.lower() == 'auto' \
            else max_num_epochs
        return self._epoch_count >= max_num_epochs

# Spellings accepted as "no limit" by the max_num_{iterations,epochs} setters.
INF_STRINGS = {'inf', 'infinity', 'infty'}

def set_max_num_iterations(self, max_num_iterations):
    """
    Set the maximum number of training iterations.

    Parameters
    ----------
    max_num_iterations : int or float or str
        Maximum number of training iterations. If float, it should equal numpy.inf.
        If str, it should be one of {'inf', 'infinity', 'infty'}.

    Returns
    -------
    Trainer
        self
    """
    max_num_iterations = \
        inf if max_num_iterations in self.INF_STRINGS else max_num_iterations
    # Validate type
    assert_(isinstance(max_num_iterations, int) or max_num_iterations == inf,
            "max_num_iterations must be an integer or numpy.inf, got {} instead."
            .format(type(max_num_iterations).__name__),
            TypeError)
    self._max_num_iterations = max_num_iterations
    return self
def set_max_num_epochs(self, max_num_epochs):
    """
    Set the maximum number of training epochs.

    Parameters
    ----------
    max_num_epochs : int or float or str
        Maximum number of training epochs. If float, it should equal numpy.inf.
        If str, it should be one of {'inf', 'infinity', 'infty'}.

    Returns
    -------
    Trainer
        self
    """
    max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs
    assert_(isinstance(max_num_epochs, int) or max_num_epochs == inf,
            "max_num_epochs must be an integer or numpy.inf, got {} instead."
            .format(type(max_num_epochs).__name__),
            TypeError)
    self._max_num_epochs = max_num_epochs
    return self

def fit(self, max_num_iterations=None, max_num_epochs=None):
    """
    Fit model.

    Parameters
    ----------
    max_num_iterations : int or float or str
        (Optional) Maximum number of training iterations. Overrides the value set
        by `Trainer.set_max_num_iterations`. If float, it should equal numpy.inf.
        If str, it should be one of {'inf', 'infinity', 'infty'}.
    max_num_epochs : int or float or str
        (Optional) Maximum number of training epochs. Overrides the value set
        by `Trainer.set_max_num_epochs`. If float, it should equal numpy.inf.
        If str, it should be one of {'inf', 'infinity', 'infty'}.

    Returns
    -------
    Trainer
        self
    """
    # Takes care of:
    #   - dispatching train
    #   - validation
    #   - learning rate scheduling
    #   - saving
    max_num_iterations = inf if max_num_iterations in self.INF_STRINGS else max_num_iterations
    max_num_iterations = self._max_num_iterations if max_num_iterations is None \
        else max_num_iterations
    max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs
    max_num_epochs = self._max_num_epochs if max_num_epochs is None else max_num_epochs

    self.callbacks.call(self.callbacks.BEGIN_OF_FIT,
                        max_num_iterations=max_num_iterations,
                        max_num_epochs=max_num_epochs)

    # Local clock
    run_num = 0
    while True:
        if self.stop_fitting(max_num_iterations, max_num_epochs):
            self.console.info("Exceeded max number of iterations / epochs, breaking.")
            break
        # Train; the break callback re-checks the budgets every iteration.
        self.train_for(break_callback=lambda *args: self.stop_fitting(max_num_iterations,
                                                                      max_num_epochs))
        # Check if it's time to validate
        if self.validate_now:
            self.console.info("Validating.")
            self.validate_for()
        # Check if it's time to save
        if self.save_now:
            self.console.info("Saving.")
            self.save()
        run_num += 1

    # Call callback
    self.callbacks.call(self.callbacks.END_OF_FIT,
                        max_num_iterations=max_num_iterations,
                        max_num_epochs=max_num_epochs,
                        num_runs=run_num)
    return self
def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
    """
    Forward `inputs` through the model and evaluate the training or validation
    criterion; optionally backprop the loss.

    Parameters
    ----------
    inputs : list
        Input tensors; splatted into the model.
    target : list or tensor
        Target(s). An empty sequence means the criterion is called on the
        prediction alone.
    backward : bool
        Whether to call backward() on the loss (honouring `self.retain_graph`).
    mode : {'train', 'eval'} or None
        Selects `self.criterion` vs `self.validation_criterion`; defaults to
        the trainer's current mode.

    Returns
    -------
    tuple
        (prediction, loss)
    """
    if mode is None:
        mode = self._current_mode
    assert_(mode in ['train', 'eval'],
            f"`mode` must be one of ['train', 'eval'], got {mode} instead.",
            ValueError)
    # Compute prediction
    prediction = self.apply_model(*inputs)
    # Compute loss
    kwargs = {}
    # Criteria whose forward() accepts a `trainer` keyword get a handle on this
    # trainer. NOTE(review): only `self.criterion.forward` is inspected, but the
    # kwarg is also forwarded to `self.validation_criterion` in eval mode —
    # confirm both criteria accept it.
    if (isinstance(self.criterion, torch.nn.Module) and
            'trainer' in signature(self.criterion.forward).parameters):
        kwargs['trainer'] = self
    if mode == 'train':
        loss = self.criterion(prediction, target, **kwargs) \
            if len(target) != 0 else self.criterion(prediction, **kwargs)
    elif mode == 'eval':
        loss = self.validation_criterion(prediction, target, **kwargs) \
            if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
    else:
        raise ValueError
    if backward:
        # Backprop if required
        # retain_graph option is needed for some custom
        # loss functions like malis, False per default
        if self.mixed_precision:
            # Loss scaling for mixed-precision training (NVIDIA apex `amp`).
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward(retain_graph=self.retain_graph)
        else:
            loss.backward(retain_graph=self.retain_graph)
    return prediction, loss
def train_for(self, num_iterations=None, break_callback=None):
    """
    Train for `num_iterations` iterations (or until `break_callback` returns
    True, or a validation/save break is due). Runs the full training loop:
    batch fetch, forward/backward, metric evaluation, state updates and
    optimizer steps.
    """
    # Switch model to train mode
    self.train_mode()
    # Call callback
    self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_RUN, num_iterations=num_iterations)
    # iteration_num is a local clock. There's the global self._iteration_count that keeps
    # actual track of the number of iterations - this is updated by the call to
    # self.next_iteration().
    iteration_num = 0
    while True:
        if num_iterations is not None and iteration_num >= num_iterations:
            self.console.info("Finished {} iterations. Breaking...".format(num_iterations))
            break
        # Break if break callback asks us to
        if break_callback is not None and break_callback(iteration_num):
            self.console.info("Breaking on request from callback.")
            break
        self.console.progress("Training iteration {} (batch {} of epoch {})."
                              .format(iteration_num, self._batch_count, self._epoch_count))
        # Call callback
        self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_ITERATION,
                            iteration_num=iteration_num)
        # No interrupts while computing - a SIGINT could shoot down the driver if
        # done at the wrong time. Not sure if this has something to do with pinned memory
        with pyu.delayed_keyboard_interrupt():
            # Get batch
            batch = self.fetch_next_batch('train')
            # Send to device and wrap as variable
            batch = self.wrap_batch(batch, from_loader='train')
            # Separate inputs from targets
            inputs, target = self.split_batch(batch, from_loader='train')
            # Apply model, compute loss and backprop
            prediction, loss = self.apply_model_and_loss(inputs, target,
                                                         backward=True, mode='train')
        self.callbacks.call(self.callbacks.AFTER_MODEL_AND_LOSS_IS_APPLIED,
                            prediction=prediction, loss=loss, iteration_num=iteration_num)
        # Compute metric
        if self.metric_is_defined and self.evaluate_metric_now:
            self._last_metric_evaluated_at_epoch = self._epoch_count
            # TODO Make unwrap a method for folks to overload
            error = self.metric(thu.unwrap(prediction, to_cpu=False),
                                thu.unwrap(target, to_cpu=False))
            self.update_state('training_error', thu.unwrap(error))
        else:
            error = None
        # Update state from computation
        self.update_state('training_inputs', thu.unwrap(inputs))
        self.update_state('training_target', thu.unwrap(target))
        self.update_state('training_prediction', thu.unwrap(prediction))
        self.update_state('training_loss', thu.unwrap(loss))
        # Update state from model's state hooks
        self.update_state_from_model_state_hooks()
        # Gradient accumulation: step/zero only every `backprop_every` iterations.
        if iteration_num % self.backprop_every == 0:
            # Update parameters
            self.optimizer.step()
            # Zero out the grads
            self.optimizer.zero_grad()
        # Call callback
        self.callbacks.call(self.callbacks.END_OF_TRAINING_ITERATION,
                            iteration_num=iteration_num)
        # Prepare for next iteration
        self.next_iteration()
        # Break if validating or saving. It's important that the next_iteration() method is
        # called before checking validate_now and save_now - because otherwise, the iteration
        # counter is never updated after the first save and validate, resulting in an infinite
        # save + validate loop.
        if self.validate_now:
            self.console.info("Breaking to validate.")
            break
        if self.save_now:
            self.console.info("Breaking to save.")
            break
        iteration_num += 1
    self.callbacks.call(self.callbacks.END_OF_TRAINING_RUN, num_iterations=num_iterations)
    return self
def validate_for(self, num_iterations=None, loader_name='validate'):
    """
    Validate for a given number of validation (if `num_iterations is not None`)
    or over the entire (validation) data set.

    Parameters
    ----------
    num_iterations : int
        Number of iterations to validate for. To validate on the entire dataset,
        leave this as `None`.
    loader_name : str
        Name of the data loader to use for validation.
        'validate' is the obvious default.

    Returns
    -------
    Trainer
        self.
    """
    assert_(loader_name in ['validate', 'test', 'train'],
            "Invalid `loader_name`: {}".format(loader_name), ValueError)
    # Average over errors
    validation_error_meter = tu.AverageMeter()
    validation_loss_meter = tu.AverageMeter()
    iteration_num = 0
    num_iterations = \
        self._num_validation_iterations if num_iterations is None else num_iterations
    # Switch to eval mode (e.g. for batchnorm, etc.)
    self.eval_mode()
    if loader_name not in self._loader_iters:
        self._loader_iters.update({loader_name: self._loaders[loader_name].__iter__()})
    # If we don't know num_iterations, we're validating the entire dataset - so we might as
    # well restart the loader now
    if num_iterations is None:
        self.restart_generators(loader_name)
    # Record the epoch we're validating in
    self._last_validated_at_epoch = self._epoch_count
    self._last_validated_at_iteration = self._iteration_count
    self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_RUN,
                        num_iterations=num_iterations,
                        num_iterations_in_generator=len(self._loader_iters[loader_name]),
                        last_validated_at_epoch=self._last_validated_at_epoch)
    while True:
        if num_iterations is not None and iteration_num >= num_iterations:
            break
        self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_ITERATION,
                            iteration_num=iteration_num)
        try:
            # Don't restart exhausted generators when validating the whole set:
            # the StopIteration below is then the loop's termination signal.
            batch = self.fetch_next_batch(
                loader_name,
                restart_exhausted_generators=num_iterations is not None,
                update_batch_count=False,
                update_epoch_count_if_generator_exhausted=False)
        except StopIteration:
            self.console.info("{} generator exhausted, breaking.".format(loader_name))
            break
        self.console.progress("Validating iteration {}.".format(iteration_num))
        # Delay SIGINTs till after computation
        with pyu.delayed_keyboard_interrupt(), torch.no_grad():
            # Wrap
            batch = self.wrap_batch(batch, from_loader=loader_name)
            # Separate
            inputs, target = self.split_batch(batch, from_loader=loader_name)
            # Apply model, compute loss
            output, loss = self.apply_model_and_loss(inputs, target,
                                                     backward=False, mode='eval')
        # Batch size is read off the target tensor along `_target_batch_dim`.
        if isinstance(target, (list, tuple)):
            batch_size = target[0].size(self._target_batch_dim)
        else:
            batch_size = target.size(self._target_batch_dim)
        validation_loss_meter.update(thu.unwrap(loss, extract_item=True), n=batch_size)
        # Compute validation_error
        if self.metric_is_defined:
            validation_error = self.metric(thu.unwrap(output, to_cpu=False),
                                           thu.unwrap(target, to_cpu=False))
            if torch.is_tensor(validation_error):
                # Convert to float
                validation_error = thu.unwrap(validation_error, extract_item=True)
            self.update_state('validation_error', thu.unwrap(validation_error))
            validation_error_meter.update(validation_error, n=batch_size)
        self.update_state('validation_inputs', thu.unwrap(inputs))
        self.update_state('validation_target', thu.unwrap(target))
        self.update_state('validation_prediction', thu.unwrap(output))
        self.update_state('validation_loss', thu.unwrap(loss))
        # This is here for legacy reasons and will eventually be deprecated.
        self.update_state('validation_input', self.get_state('validation_inputs'))
        # Update from model's state hooks
        self.update_state_from_model_state_hooks()
        self.callbacks.call(self.callbacks.END_OF_VALIDATION_ITERATION,
                            iteration_num=iteration_num)
        iteration_num += 1
    self.console.info("Done validating. Logging results...")
    # Report
    validation_results = {
        'validation_loss': validation_loss_meter.avg,
        'validation_error': (validation_error_meter.avg
                             if self.metric_is_defined else None)
    }
    self.record_validation_results(**validation_results)
    self.console.info("Validation loss: {validation_loss}; validation error: "
                      "{validation_error}".format(**validation_results))
    self.callbacks.call(self.callbacks.END_OF_VALIDATION_RUN,
                        validation_loss_meter=validation_loss_meter,
                        validation_error_meter=(validation_error_meter
                                                if self.metric_is_defined else None))
    return self
def record_validation_results(self, validation_loss, validation_error):
    """Store the averaged validation results in the trainer state and track the
    best score seen so far (the error metric is preferred over the loss when
    it is available)."""
    self.update_state('validation_loss_averaged', thu.unwrap(validation_loss))
    if validation_error is not None:
        self.update_state('validation_error_averaged', thu.unwrap(validation_error))
    # Prefer the error metric (if provided). This should be handled with care -
    # validation error should either always be provided, or never.
    validation_score = validation_error if validation_error is not None else validation_loss
    best_so_far = self._best_validation_score
    if best_so_far is None or validation_score < best_so_far:
        # New best score; this flag makes the next save() stash a best-checkpoint copy.
        self._is_iteration_with_best_validation_score = True
        self._best_validation_score = validation_score
validation_score = validation_loss if validation_error is None else validation_error # Check if validation error is less than the best so far if self._best_validation_score is None or validation_score < self._best_validation_score: # Best score so far. The following flag will trigger a save self._is_iteration_with_best_validation_score = True self._best_validation_score = validation_score def get_config(self, exclude_loader=True): # Returns a config dictionary, like __getstate__. Except optionally without the # data loaders (which might be yuuuuuge if it contains the data) config_dict = dict(self.__dict__) # Loader iterators can't be pickled if '_loader_iters' in config_dict: config_dict.update({'_loader_iters': {}}) if exclude_loader: if '_loaders' in config_dict: config_dict.update({'_loaders': {}}) return config_dict def set_config(self, config_dict): # TODO some sanity checks on config_dict (e.g. whether the model is actually a model, etc) self.__dict__.update(config_dict) # Rebind trainer to callback engine self.callbacks.bind_trainer(self) # Have callback engine rebind all callbacks to trainer self.callbacks.rebind_trainer_to_all_callbacks() return self def save(self, exclude_loader=True, stash_best_checkpoint=True): # Log the epoch for save_now self._last_saved_at_epoch = self._epoch_count self.callbacks.call(self.callbacks.BEGIN_OF_SAVE, save_to_directory=self._save_to_directory, epoch_count=self._epoch_count, batch_count=self._batch_count, iteration_count=self._iteration_count, is_iteration_with_best_validation_score=self._is_iteration_with_best_validation_score) checkpoint_path = os.path.join(self._save_to_directory, self._checkpoint_filename) best_checkpoint_path = os.path.join(self._save_to_directory, self._best_checkpoint_filename) # Save the state dictionary torch.save(self.get_config(exclude_loader=exclude_loader), checkpoint_path, pickle_module=self.pickle_module) self.callbacks.call(self.callbacks.END_OF_SAVE, 
def save_model(self, to_directory=None):
    """Save just the model (not the whole trainer) to `<to_directory>/model.pytorch`;
    defaults to the trainer's save directory."""
    to_directory = self._save_to_directory if to_directory is None else to_directory
    # Save the state dictionary
    torch.save(self.model, os.path.join(to_directory, 'model.pytorch'),
               pickle_module=self.pickle_module)
    return self

def load(self, from_directory=None, best=False, filename=None, map_location=None):
    """
    Load the trainer from checkpoint.

    Parameters
    ----------
    from_directory : str
        Path to the directory where the checkpoint is located. The filename should be
        'checkpoint.pytorch' if best=False, or 'best_checkpoint.pytorch' if best=True.
    best : bool
        Whether to load the best checkpoint. The filename in `from_directory` should be
        'best_checkpoint.pytorch'.
    filename : str
        Overrides the default filename.
    map_location : function, torch.device, string or a dict
        Specify how to remap storage locations.

    Returns
    -------
    Trainer
        self
    """
    from_directory = self._save_to_directory if from_directory is None else from_directory
    assert from_directory is not None, "Nowhere to load from."
    # Get file name
    if filename is None:
        filename = self._best_checkpoint_filename if best else self._checkpoint_filename
    # Load the dictionary
    config_dict = torch.load(os.path.join(from_directory, filename),
                             pickle_module=self.pickle_module,
                             map_location=map_location)
    # This is required to prevent an infinite save loop?
    self._is_iteration_with_best_validation_score = False
    # Set config
    self.set_config(config_dict)
    return self
self._is_iteration_with_best_validation_score = False # Set config self.set_config(config_dict) return self def load_model(self, from_directory=None, filename=None): from_directory = self._save_to_directory if from_directory is None else from_directory filename = 'model.pytorch' if filename is None else filename # Load the model model = torch.load(os.path.join(from_directory, filename), pickle_module=self.pickle_module) # Set model self.model = model return self def load_(self, *args, **kwargs): # Here for legacy reasons - use load instead. return self.load(*args, **kwargs) @pyu.deprecated("please use self.console.{info,progress,warning,debug} instead") def print(self, message): print("[+][{}] {}".format(str(datetime.now()), message)) @classmethod def build(cls, model=None, **trainer_config): """Factory function to build the trainer.""" # Check if trainer is to be loaded from file if trainer_config.get('load_from_checkpoint'): # Load checkpoint config trainer = cls(model).save_every(**trainer_config.get('checkpoint_config')) trainer.load_() else: trainer = cls(model) if 'logger_config' in trainer_config: trainer.build_logger(**trainer_config.get('logger_config')) if 'criterion_config' in trainer_config: trainer.build_criterion(**trainer_config.get('criterion_config')) if 'optimizer_config' in trainer_config: trainer.build_optimizer(**trainer_config.get('optimizer_config')) if 'metric_config' in trainer_config: trainer.build_metric(**trainer_config.get('metric_config')) if 'checkpoint_config' in trainer_config: trainer.save_every(**trainer_config.get('checkpoint_config')) if 'validation_config' in trainer_config: trainer.validate_every(**trainer_config.get('validation_config')) if 'max_num_iterations' in trainer_config: trainer.set_max_num_iterations(trainer_config.get('max_num_iterations')) if 'max_num_epochs' in trainer_config: trainer.set_max_num_epochs(trainer_config.get('max_num_epochs')) if trainer_config.get('use_cuda'): devices = 
trainer_config.get('use_cuda').get('devices') \ if isinstance(trainer_config.get('use_cuda'), dict) else None trainer.cuda(devices=devices) if 'training_precision' in trainer_config: trainer.set_precision(trainer_config.get('training_precision')) return trainer ================================================ FILE: inferno/trainers/callbacks/__init__.py ================================================ __all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients'] from .base import CallbackEngine, Callback from .console import Console from . import essentials from . import scheduling from . import gradients try: from .tqdm import TQDMProgressBar __all__.append('TQDMProgressBar') except ImportError: from .tqdmstub import TQDMProgressBar ================================================ FILE: inferno/trainers/callbacks/base.py ================================================ from ...utils import python_utils as pyu class CallbackEngine(object): """ Gathers and manages callbacks. Callbacks are callables which are to be called by trainers when certain events ('triggers') occur. They could be any callable object, but if endowed with a `bind_trainer` method, it's called when the callback is registered. It is recommended that callbacks (or their `__call__` methods) use the double-star syntax for keyword arguments. 
""" # Triggers BEGIN_OF_FIT = 'begin_of_fit' END_OF_FIT = 'end_of_fit' BEGIN_OF_TRAINING_RUN = 'begin_of_training_run' END_OF_TRAINING_RUN = 'end_of_training_run' BEGIN_OF_EPOCH = 'begin_of_epoch' END_OF_EPOCH = 'end_of_epoch' BEGIN_OF_TRAINING_ITERATION = 'begin_of_training_iteration' AFTER_MODEL_AND_LOSS_IS_APPLIED = 'after_model_and_loss_is_applied' END_OF_TRAINING_ITERATION = 'end_of_training_iteration' BEGIN_OF_VALIDATION_RUN = 'begin_of_validation_run' END_OF_VALIDATION_RUN = 'end_of_validation_run' BEGIN_OF_VALIDATION_ITERATION = 'begin_of_validation_iteration' END_OF_VALIDATION_ITERATION = 'end_of_validation_iteration' BEGIN_OF_SAVE = 'begin_of_save' END_OF_SAVE = 'end_of_save' TRIGGERS = {BEGIN_OF_FIT, END_OF_FIT, BEGIN_OF_TRAINING_RUN, END_OF_TRAINING_RUN, BEGIN_OF_EPOCH, END_OF_EPOCH, BEGIN_OF_TRAINING_ITERATION, AFTER_MODEL_AND_LOSS_IS_APPLIED, END_OF_TRAINING_ITERATION, BEGIN_OF_VALIDATION_RUN, END_OF_VALIDATION_RUN, BEGIN_OF_VALIDATION_ITERATION, END_OF_VALIDATION_ITERATION, BEGIN_OF_SAVE, END_OF_SAVE} def __init__(self): self._trainer = None self._callback_registry = {trigger: set() for trigger in self.TRIGGERS} self._last_known_epoch = None self._last_known_iteration = None def register_new_trigger(self, trigger_name): self.TRIGGERS.add(trigger_name) self._callback_registry.update({trigger_name: set()}) def bind_trainer(self, trainer): self._trainer = trainer return self def unbind_trainer(self): self._trainer = None return self @property def trainer_is_bound(self): return self._trainer is not None def register_callback(self, callback, trigger='auto', bind_trainer=True): assert callable(callback) # Automatic callback registration based on their methods if trigger == 'auto': automatic_registration_successful = False for trigger in self.TRIGGERS: if pyu.has_callable_attr(callback, trigger): automatic_registration_successful = True self.register_callback(callback, trigger, bind_trainer) assert automatic_registration_successful, \ "Callback could not be 
class Callback(object):
    """Recommended (but not required) base class for callbacks."""

    def __init__(self):
        self._trainer = None
        self._debugging = False
        self.register_instance(self)

    @classmethod
    def register_instance(cls, instance):
        # Lazily create the per-class instance registry on first registration.
        # NOTE(review): re-registering an instance that is already in the registry
        # falls into the else-branch and resets the registry to just that
        # instance — looks unintended; confirm.
        if hasattr(cls, '_instance_registry') and instance not in cls._instance_registry:
            cls._instance_registry.append(instance)
        else:
            cls._instance_registry = [instance]

    @classmethod
    def get_instances(cls):
        if hasattr(cls, '_instance_registry'):
            return pyu.from_iterable(cls._instance_registry)
        else:
            return None

    @property
    def trainer(self):
        return self._trainer

    def bind_trainer(self, trainer):
        self._trainer = trainer
        return self

    def unbind_trainer(self):
        self._trainer = None
        return self

    def __call__(self, **kwargs):
        # Dispatch to the method named after the trigger, if one is defined.
        if 'trigger' in kwargs:
            trigger_name = kwargs.get('trigger')
            if hasattr(self, trigger_name) and callable(getattr(self, trigger_name)):
                getattr(self, trigger_name)(**kwargs)

    def get_config(self):
        # Drop the trainer reference so the callback can be pickled.
        config_dict = dict(self.__dict__)
        config_dict.update({'_trainer': None})
        return config_dict

    def set_config(self, config_dict):
        self.__dict__.update(config_dict)
        return self

    def __getstate__(self):
        return self.get_config()

    def __setstate__(self, state):
        self.set_config(state)

    def toggle_debug(self):
        self._debugging = not self._debugging
        return self

    def debug_print(self, message):
        if self._debugging:
            self.trainer.console.debug("[{}] {}".format(type(self).__name__, message))
self._trainer def bind_trainer(self, trainer): self._trainer = trainer return self def unbind_trainer(self): self._trainer = None return self def __call__(self, **kwargs): if 'trigger' in kwargs: if hasattr(self, kwargs.get('trigger')) and \ callable(getattr(self, kwargs.get('trigger'))): getattr(self, kwargs.get('trigger'))(**kwargs) def get_config(self): config_dict = dict(self.__dict__) config_dict.update({'_trainer': None}) return config_dict def set_config(self, config_dict): self.__dict__.update(config_dict) return self def __getstate__(self): return self.get_config() def __setstate__(self, state): self.set_config(state) def toggle_debug(self): self._debugging = not self._debugging return self def debug_print(self, message): if self._debugging: self.trainer.console.debug("[{}] {}".format(type(self).__name__, message)) ================================================ FILE: inferno/trainers/callbacks/console.py ================================================ from datetime import datetime from .base import Callback class StdoutPrinter(object): def print(self, message): print("[+][{}] {}".format(str(datetime.now()), message)) class Console(object): LEVEL_INFO = 1 LEVEL_PROGRESS = 2 LEVEL_WARNING = 3 LEVEL_DEBUG = 4 def __init__(self, printer=StdoutPrinter()): self._printer = printer self._enabled = {self.LEVEL_INFO, self.LEVEL_PROGRESS, self.LEVEL_WARNING} def set_console(self, console): self._printer = console def _print(self, message, level): if level not in self._enabled: return self._printer.print(message) def info(self, message): self._print("[INFO ] " + message, self.LEVEL_INFO) def print(self, message): self.info(message) def progress(self, message): self._print("[PROGRESS] " + message, self.LEVEL_PROGRESS) def warning(self, message): self._print("[WARNING ] " + message, self.LEVEL_WARNING) def debug(self, message): self._print("[DEBUG ] " + message, self.LEVEL_DEBUG) def _toggle(self, level, state): if state: self._enabled.add(level) else: if level in 
class ShowMinimalConsoleInfo(Callback):
    """
    Callback to show only minimum training info on console
    viz. current epoch number, current learning rate,
    training loss and training error if exists.
    """
    def __init__(self, *args, **kwargs):
        super(ShowMinimalConsoleInfo, self).__init__(*args, **kwargs)

    def begin_of_fit(self, **_):
        # Silence the trainer's default console chatter; we print our own summary.
        self.trainer.quiet()

    def end_of_epoch(self, **_):
        trainer = self.trainer
        console = trainer.console
        loss = trainer.get_state('training_loss')
        error = trainer.get_state('training_error')
        lr = trainer.get_state('learning_rate')
        console.info("--------------------------------")
        console.info("Epoch " + str(trainer.epoch_count))
        # NOTE(review): `.item()` assumes these states are scalar tensors — matches
        # how the trainer stores them, but confirm before reusing elsewhere.
        if loss is not None:
            console.info("Train Loss " + str(loss.item()))
        if error is not None:
            console.info("Train Error " + str(error.item()))
        console.info("Current LR " + str(lr))
class DumpHDF5Every(Callback):
    """Dumps intermediate training states to a HDF5 file."""
    def __init__(self, frequency, to_directory,
                 filename_template='dump.{mode}.epoch{epoch_count}.iteration{iteration_count}.h5',
                 force_dump=False, dump_after_every_validation_run=False):
        """
        Parameters
        ----------
        frequency : str or tuple or inferno.utils.train_utils.Frequency
            How often to dump (anything `Frequency.build_from` accepts).
        to_directory : str
            Directory the HDF5 files are written to (created if missing).
        filename_template : str
            Format string for the dump filename; may use `mode`, `epoch_count`
            and `iteration_count`.
        force_dump : bool
            If True, raise when a cached state cannot be unwrapped to numpy;
            otherwise such states are silently skipped.
        dump_after_every_validation_run : bool
            If True, additionally dump the validation states at the end of
            every validation run.
        """
        super(DumpHDF5Every, self).__init__()
        # Privates
        self._dump_every = None
        # Trainer state keys collected into the dump at the respective event.
        self._trainer_states_to_be_dumped_while_training = {'training_inputs',
                                                            'training_target',
                                                            'training_prediction'}
        self._trainer_states_to_be_dumped_while_validating = {'validation_inputs',
                                                              'validation_target',
                                                              'validation_prediction'}
        # Staging area: key -> object, filled just before a dump and cleared after.
        self._dump_cache = {}
        # Publics
        self.dump_every = frequency
        self.dump_directory = to_directory
        self.dump_filename_template = filename_template
        self.force_dump = force_dump
        self.dump_after_every_validation_run = dump_after_every_validation_run

    @property
    def dump_every(self):
        # The dump `Frequency`; settable with anything `Frequency.build_from` accepts.
        return self._dump_every

    @dump_every.setter
    def dump_every(self, value):
        self._dump_every = Frequency.build_from(value)
        assert_(self._dump_every.is_consistent,
                "Dump frequency is not consistent.",
                FrequencyValueError)

    @property
    def dump_now(self):
        # Whether the current trainer iteration/epoch matches the dump frequency.
        # persistent=True: the Frequency object remembers this match, so the
        # property should be read at most once per iteration.
        return self.dump_every.match(iteration_count=self.trainer.iteration_count,
                                     epoch_count=self.trainer.epoch_count,
                                     persistent=True, match_zero=True)

    def add_to_dump_cache(self, key, value):
        """Stage `value` under `key`; list-likes are flattened recursively to
        keys of the form '{key}_{index}'."""
        if pyu.is_listlike(value):
            for value_num, _value in enumerate(value):
                self.add_to_dump_cache("{}_{}".format(key, value_num), _value)
        else:
            self._dump_cache.update({key: value})

    def clear_dump_cache(self):
        # Drop all staged objects.
        self._dump_cache.clear()

    def dump_state(self, key, dump_while='training'):
        """Register an extra trainer state `key` to be dumped while training or
        validating. Returns self for chaining."""
        # Validate arguments
        keyword_mapping = {'train': 'training',
                           'training': 'training',
                           'validation': 'validating',
                           'validating': 'validating'}
        dump_while = keyword_mapping.get(dump_while)
        assert_(dump_while is not None,
                "The keyword dump_while must be one of: {}."
                .format(set(keyword_mapping.keys())),
                ValueError)
        assert_(isinstance(key, str),
                "State key must be a string, got {} instead.".format(type(key).__name__),
                TypeError)
        # Add to set of observed states
        if dump_while == 'training':
            self._trainer_states_to_be_dumped_while_training.add(key)
        elif dump_while == 'validating':
            self._trainer_states_to_be_dumped_while_validating.add(key)
        else:
            raise NotImplementedError
        return self

    def dump_states(self, keys, dump_while='training'):
        """Register multiple state keys at once. Returns self for chaining."""
        for key in keys:
            self.dump_state(key, dump_while=dump_while)
        return self

    def get_file_path(self, mode):
        """Build the output path for the current epoch/iteration and `mode`,
        creating the dump directory if it does not exist yet."""
        # Make sure the dump directory exists
        if not os.path.exists(self.dump_directory):
            os.mkdir(self.dump_directory)
        else:
            assert_(os.path.isdir(self.dump_directory),
                    "Dump directory {} is a file.".format(self.dump_directory),
                    FileExistsError)
        filename = self.dump_filename_template.format(epoch_count=self.trainer.epoch_count,
                                                      iteration_count=self.trainer.iteration_count,
                                                      mode=mode)
        return os.path.join(self.dump_directory, filename)

    def dump(self, mode):
        """Write every staged (non-None) object in the cache to a fresh HDF5
        file, one dataset per key. Objects that can't be unwrapped to numpy
        are skipped unless `force_dump` is set."""
        with h5.File(name=self.get_file_path(mode), mode='w') as h5_file:
            for key, to_dump in self._dump_cache.items():
                if to_dump is None:
                    continue
                try:
                    to_dump = tu.unwrap(to_dump, as_numpy=True)
                except NotUnwrappableError:
                    # Can't unwrap to_dump, but let's not throw a tantrum if we're not required to
                    if not self.force_dump:
                        continue
                    else:
                        raise
                # Do the dumpin'
                h5_file.create_dataset(name=key, data=to_dump)

    def end_of_training_iteration(self, **_):
        # Dump the registered training states whenever the frequency matches.
        dump_now = self.dump_now
        if dump_now:
            # To be double sure
            self.clear_dump_cache()
            # Get object to dump
            for state_name in self._trainer_states_to_be_dumped_while_training:
                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))
            # Dump
            self.dump(mode='training')
            # Clear cache
            self.clear_dump_cache()

    def end_of_validation_run(self, **_):
        # Optionally dump the registered validation states after each validation run.
        if self.dump_after_every_validation_run:
            # To be double sure
            self.clear_dump_cache()
            # Get object to dump
            for state_name in self._trainer_states_to_be_dumped_while_validating:
                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))
            # Dump
            self.dump(mode='validation')
            # Clear cache
            self.clear_dump_cache()
class ParameterEMA(Callback):
    """Maintain an exponential moving average (EMA) of network parameters."""
    def __init__(self, momentum):
        """
        Parameters
        ----------
        momentum : float
            Momentum for the moving average. The following holds:
            `new_moving_average = momentum * old_moving_average + (1 - momentum) * value`
        """
        super(ParameterEMA, self).__init__()
        # Privates
        # Lazily-allocated list of EMA tensors, one per model parameter.
        self._parameters = None
        # Publics
        self.momentum = momentum

    def maintain(self):
        """Fold the model's current parameter values into the moving average."""
        if self._parameters is None:
            # Allocate zero tensors of the *same shape* as each parameter.
            # FIX: the previous `p.data.new().zero_()` produced empty
            # (0-element) tensors, so the in-place update below could not work.
            self._parameters = [p.data.new_zeros(p.data.size())
                                for p in self.trainer.model.parameters()]
        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
            # ema = momentum * ema + (1 - momentum) * param
            p_ema.mul_(self.momentum).add_(p_model.data.mul(1. - self.momentum))

    def apply(self):
        """Overwrite the model's parameters with the maintained EMA values."""
        assert_(self._parameters is not None,
                "Can't apply parameter EMA's: not available.",
                ValueError)
        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):
            p_model.data.copy_(p_ema)

    def end_of_training_iteration(self, **_):
        # Update the running average once per training iteration.
        self.maintain()
class LogOutputGradients(Callback):
    """Logs the gradient of the network output"""

    def __init__(self, frequency):
        super(LogOutputGradients, self).__init__()
        self.log_every = frequency
        self.registered = False
        self.hook_handle = None

    @property
    def log_every(self):
        # Logging `Frequency`, always in iterations.
        return self._log_every

    @log_every.setter
    def log_every(self, value):
        self._log_every = Frequency(value, 'iterations')
        assert_(self.log_every.is_consistent,
                "Log frequency is not consistent.",
                FrequencyValueError)

    def hook(self, module, grad_input, grad_output):
        """Backward hook: capture the model's output gradient into trainer state."""
        # Detach the hook if the trainer has gone away.
        if self.trainer is None:
            self.hook_handle.remove()
            return
        should_log = self.log_every.match(
            iteration_count=self.trainer.iteration_count,
            epoch_count=self.trainer.epoch_count,
            persistent=True, match_zero=True)
        if should_log:
            # Copy to CPU float so the tensor outlives the backward pass.
            gradient = grad_output[0].detach().float().clone().cpu()
            self.trainer.update_state('output_gradient', gradient)

    def add_hook(self):
        # Attach ourselves as a backward hook on the model.
        self.hook_handle = self.trainer.model.register_backward_hook(self.hook)

    def begin_of_fit(self, **kwargs):
        # Make the logger plot the captured gradient alongside training states.
        self._trainer.logger.observe_state("output_gradient", observe_while='training')
        self.add_hook()

    def begin_of_save(self, **_):
        # Hook handles can't be pickled, so drop the hook before the model is saved.
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None

    def end_of_save(self, **_):
        # Re-attach the hook once saving is done.
        self.add_hook()
""" def __init__(self, log_directory=None): super(Logger, self).__init__() self._log_directory = None if log_directory is not None: self.set_log_directory(log_directory) @property def log_directory(self): if self._log_directory is not None: return self._log_directory elif self.trainer is not None and self.trainer._log_directory is not None: return self.trainer._log_directory else: raise RuntimeError("No log directory found.") @log_directory.setter def log_directory(self, value): self.set_log_directory(value) def set_log_directory(self, log_directory): assert isinstance(log_directory, str) if not os.path.isdir(log_directory): assert not os.path.exists(log_directory) os.makedirs(log_directory) self._log_directory = log_directory return self ================================================ FILE: inferno/trainers/callbacks/logging/tensorboard.py ================================================ import warnings import numpy as np from torch.utils.tensorboard import SummaryWriter from .base import Logger from ....utils import torch_utils as tu from ....utils import python_utils as pyu from ....utils import train_utils as tru from ....utils.exceptions import assert_ class TaggedImage(object): def __init__(self, array, tag): self.array = array self.tag = tag class TensorboardLogger(Logger): """Class to enable logging of training progress to Tensorboard. Currently supports logging scalars and images. """ # This is hard coded because tensorboardX doesn't have a __version__ _TENSORBOARDX_IMAGE_FORMAT = 'CHW' _DEBUG = False def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=None, log_histograms_every=None, send_image_at_batch_indices='all', send_image_at_channel_indices='all', send_volume_at_z_indices='mid'): """ Parameters ---------- log_directory : str Path to the directory where the log files will be placed. log_scalars_every : str or tuple or inferno.utils.train_utils.Frequency How often scalars should be logged to Tensorboard. 
By default, once every iteration. log_images_every : str or tuple or inferno.utils.train_utils.Frequency How often images should be logged to Tensorboard. By default, once every epoch. log_histograms_every : str or tuple or inferno.utils.train_utils.Frequency How often histograms should be logged to Tensorboard. By default, never. send_image_at_batch_indices : list or str The indices of the batches to be logged. An `image_batch` usually has the shape (num_samples, num_channels, num_rows, num_cols). By setting this argument to say [0, 2], only images corresponding to `image_batch[0]` and `image_batch[2]` are logged. When a str, it should be 'all', in which case, all samples are logged. send_image_at_channel_indices : list or str Similar to `send_image_at_batch_indices`, but applying to channels. send_volume_at_z_indices : list or str For 3D batches of shape (num_samples, num_channels, num_z_slices, num_rows, num_cols), select the indices of the z slices to be logged. When a str, it could be 'all' or 'mid' (to log the central z slice). Warnings -------- Leaving log_images_every to the default (i.e. once every iteration) might generate a large logfile and/or slow down the training. """ super(TensorboardLogger, self).__init__(log_directory=log_directory) self._log_scalars_every = None self._log_images_every = None self._log_histograms_every = None self._writer = None self._config = {'image_batch_indices': send_image_at_batch_indices, 'image_channel_indices': send_image_at_channel_indices, 'volume_z_indices': send_volume_at_z_indices} # We ought to know the trainer states we're observing (and plotting to tensorboard). # These are the defaults. 
self._trainer_states_being_observed_while_training = {'training_loss', 'training_error', 'training_prediction', 'training_inputs', 'training_target', 'learning_rate'} self._trainer_states_being_observed_while_validating = {'validation_error_averaged', 'validation_loss_averaged'} if log_scalars_every is not None: self.log_scalars_every = log_scalars_every if log_images_every is not None: self.log_images_every = log_images_every if log_histograms_every is not None: self.log_histograms_every = log_histograms_every @property def writer(self): if self._writer is None: self._writer = SummaryWriter(self.log_directory) return self._writer @property def log_scalars_every(self): if self._log_scalars_every is None: self._log_scalars_every = tru.Frequency(1, 'iterations') return self._log_scalars_every @log_scalars_every.setter def log_scalars_every(self, value): self._log_scalars_every = tru.Frequency.build_from(value) @property def log_scalars_now(self): # Using persistent=True in a property getter is probably not a very good idea... # We need to make sure that this getter is called only once per callback-call. return self.log_scalars_every.match(iteration_count=self.trainer.iteration_count, epoch_count=self.trainer.epoch_count, persistent=True) @property def log_images_every(self): if self._log_images_every is None: self._log_images_every = tru.Frequency(1, 'epochs') return self._log_images_every @log_images_every.setter def log_images_every(self, value): self._log_images_every = tru.Frequency.build_from(value) @property def log_images_now(self): # Using persistent=True in a property getter is probably not a very good idea... # We need to make sure that this getter is called only once per callback-call. 
return self.log_images_every.match(iteration_count=self.trainer.iteration_count, epoch_count=self.trainer.epoch_count, persistent=True) @property def log_histograms_every(self): if self._log_histograms_every is None: self._log_histograms_every = tru.Frequency('never') return self._log_histograms_every @log_histograms_every.setter def log_histograms_every(self, value): self._log_histograms_every = tru.Frequency.build_from(value) @property def log_histograms_now(self): # Using persistent=True in a property getter is probably not a very good idea... # We need to make sure that this getter is called only once per callback-call. return self.log_histograms_every.match(iteration_count=self.trainer.iteration_count, epoch_count=self.trainer.epoch_count, persistent=True) def observe_state(self, key, observe_while='training'): # Validate arguments keyword_mapping = {'train': 'training', 'training': 'training', 'validation': 'validating', 'validating': 'validating'} observe_while = keyword_mapping.get(observe_while) assert_(observe_while is not None, "The keyword observe_while must be one of: {}." 
.format(set(keyword_mapping.keys())), ValueError) assert_(isinstance(key, str), "State key must be a string, got {} instead.".format(type(key).__name__), TypeError) # Add to set of observed states if observe_while == 'training': self._trainer_states_being_observed_while_training.add(key) elif observe_while == 'validating': self._trainer_states_being_observed_while_validating.add(key) else: raise NotImplementedError return self def unobserve_state(self, key, observe_while='training'): if observe_while == 'training': self._trainer_states_being_observed_while_training.remove(key) elif observe_while == 'validating': self._trainer_states_being_observed_while_validating.remove(key) else: raise NotImplementedError return self def unobserve_states(self, keys, observe_while='training'): for key in keys: self.unobserve_state(key, observe_while=observe_while) return self def observe_training_and_validation_state(self, key): for mode in ['training', 'validation']: self.observe_state('{}_{}'.format(mode, key), observe_while=mode) def observe_states(self, keys, observe_while='training'): for key in keys: self.observe_state(key, observe_while=observe_while) return self def observe_training_and_validation_states(self, keys): for key in keys: self.observe_training_and_validation_state(key) return self def log_object(self, tag, object_, allow_scalar_logging=True, allow_image_logging=True, allow_histogram_logging=True): assert isinstance(tag, str) if isinstance(object_, (list, tuple)): for object_num, _object in enumerate(object_): self.log_object("{}_{}".format(tag, object_num), _object, allow_scalar_logging, allow_image_logging, allow_histogram_logging) return # Check whether object is a scalar if tu.is_scalar_tensor(object_) and allow_scalar_logging: # Log scalar value = tu.unwrap(object_.float(), extract_item=True) self.log_scalar(tag, value, step=self.trainer.iteration_count) elif isinstance(object_, (float, int)) and allow_scalar_logging: value = float(object_) 
self.log_scalar(tag, value, step=self.trainer.iteration_count) elif tu.is_label_image_or_volume_tensor(object_) and allow_image_logging: # Add a channel axis and log as images self.log_image_or_volume_batch(tag, object_[:, None, ...], self.trainer.iteration_count) elif tu.is_image_or_volume_tensor(object_): if allow_image_logging: # Log images self.log_image_or_volume_batch(tag, object_, self.trainer.iteration_count) elif tu.is_vector_tensor(object_) and allow_histogram_logging: # Log histograms values = tu.unwrap(object_, as_numpy=True) self.log_histogram(tag, values, self.trainer.iteration_count) else: # Object is neither a scalar nor an image nor a vector, there's nothing we can do if tu.is_tensor(object_) and self._DEBUG: # Throw a warning when in debug mode. warnings.warn("Unsupported attempt to log tensor `{}` of shape `{}`".format(tag, object_.size())) def end_of_training_iteration(self, **_): log_scalars_now = self.log_scalars_now log_images_now = self.log_images_now log_histograms_now = self.log_histograms_now if not log_scalars_now and not log_images_now: # Nothing to log, so we won't bother return # Read states for state_key in self._trainer_states_being_observed_while_training: state = self.trainer.get_state(state_key, default=None) if state is None: # State not found in trainer but don't throw a hissy fit continue self.log_object(state_key, state, allow_scalar_logging=log_scalars_now, allow_image_logging=log_images_now, allow_histogram_logging=log_histograms_now) def end_of_validation_run(self, **_): # Log everything # Read states for state_key in self._trainer_states_being_observed_while_validating: state = self.trainer.get_state(state_key, default=None) if state is None: # State not found in trainer but don't throw a hissy fit continue self.log_object(state_key, state, allow_scalar_logging=True, allow_image_logging=True, allow_histogram_logging=False) def _tag_image(self, image, base_tag, prefix=None, instance_num=None, channel_num=None, 
slice_num=None): tag = base_tag if prefix is not None: tag = '{}/{}'.format(base_tag, prefix) if instance_num is not None: tag = '{}/instance_{}'.format(tag, instance_num) if channel_num is not None: tag = '{}/channel_{}'.format(tag, channel_num) if slice_num is not None: tag = '{}/slice_{}'.format(tag, slice_num) return TaggedImage(image, tag) def extract_images_from_batch(self, batch, base_tag=None, prefix=None): if base_tag is None: assert_(prefix is None, "`base_tag` is not provided - `prefix` must be None in this case.", ValueError) # Special case when batch is a list or tuple of batches if isinstance(batch, (list, tuple)): image_list = [] for batch_num, _batch in batch: image_list.extend( self.extract_images_from_batch(_batch, base_tag=base_tag, prefix='batch_{}'.format(batch_num))) return image_list # `batch` really is a tensor from now on. batch_is_image_tensor = tu.is_image_tensor(batch) batch_is_volume_tensor = tu.is_volume_tensor(batch) assert batch_is_volume_tensor != batch_is_image_tensor, \ "Batch must either be a image or a volume tensor." 
# Convert to numpy batch = batch.float().numpy() # Get the indices of the batches we want to send to tensorboard batch_indices = self._config.get('image_batch_indices', 'all') if batch_indices == 'all': batch_indices = list(range(batch.shape[0])) elif isinstance(batch_indices, (list, tuple)): pass elif isinstance(batch_indices, int): batch_indices = [batch_indices] else: raise NotImplementedError # Get the indices of the channels we want to send to tensorboard channel_indices = self._config.get('image_channel_indices', 'all') if channel_indices == 'all': channel_indices = list(range(batch.shape[1])) elif isinstance(channel_indices, (list, tuple)): pass elif isinstance(channel_indices, int): channel_indices = [channel_indices] else: raise NotImplementedError # Extract images from batch if batch_is_image_tensor: image_list = [(self._tag_image(image, base_tag=base_tag, prefix=prefix, instance_num=instance_num, channel_num=channel_num) if base_tag is not None else image) for instance_num, instance in enumerate(batch) for channel_num, image in enumerate(instance) if instance_num in batch_indices and channel_num in channel_indices] else: assert batch_is_volume_tensor # Trim away along the z axis z_indices = self._config.get('volume_z_indices', 'mid') if z_indices == 'all': z_indices = list(range(batch.shape[2])) elif z_indices == 'mid': z_indices = [batch.shape[2] // 2] elif isinstance(z_indices, (list, tuple)): pass elif isinstance(z_indices, int): z_indices = [z_indices] else: raise NotImplementedError # I'm going to hell for this. image_list = [(self._tag_image(image, base_tag=base_tag, prefix=prefix, instance_num=instance_num, channel_num=channel_num, slice_num=slice_num) if base_tag is not None else image) for instance_num, instance in enumerate(batch) for channel_num, volume in enumerate(instance) for slice_num, image in enumerate(volume) if instance_num in batch_indices and channel_num in channel_indices and slice_num in z_indices] # Done. 
return image_list def log_image_or_volume_batch(self, tag, batch, step=None): assert pyu.is_maybe_list_of(tu.is_image_or_volume_tensor)(batch) step = step or self.trainer.iteration_count image_list = self.extract_images_from_batch(batch, base_tag=tag) self.log_images(tag, image_list, step) def log_scalar(self, tag, value, step): """ Parameter ---------- tag : basestring Name of the scalar value step : int training iteration """ self.writer.add_scalar(tag=tag, scalar_value=value, global_step=step) def log_images(self, tag, images, step, image_format='CHW'): """Logs a list of images.""" assert_(image_format.upper() in ['CHW', 'HWC'], "Image format must be either 'CHW' or 'HWC'. Got {} instead.".format(image_format), ValueError) for image_num, image in enumerate(images): if isinstance(image, TaggedImage): tag = image.tag image = image.array else: tag = "{}/{}".format(tag, image_num) # This will fail for the wrong tensorboard version. image = self._order_image_axes(image, image_format, self._TENSORBOARDX_IMAGE_FORMAT) # unfortunately tensorboardX does not have a __version__ attribute # so I don't see how to check for the version and provide backwards # compatability here # tensorboardX borks if the number of image channels is not 3 # if image.shape[-1] == 1: # image = image[..., [0, 0, 0]] image = self._normalize_image(image) # print(image.dtype, image.shape) self.writer.add_image(tag, img_tensor=image, global_step=step) @staticmethod def _order_image_axes(image, image_format='CHW', target_format='CHW'): # image axis gymnastics _not_implemented_message = "target_format must be 'CHW' or 'HCW'." if image.ndim == 2: if target_format == 'CHW': # image is 2D - tensorboardX 1.4+ needs a channel axis in the front image = image[None, ...] 
elif target_format == 'HWC': # image is 2D - tensorboardX 1.3- needs a channel axis in the end image = image[..., None] else: raise NotImplementedError(_not_implemented_message) elif image.ndim == 3 and image_format.upper() == 'CHW': if target_format == 'CHW': # Nothing to do here pass elif target_format == 'HCW': # We have a CHW image, but need HWC. image = np.moveaxis(image, 0, 2) else: raise NotImplementedError(_not_implemented_message) elif image.ndim == 3 and image_format.upper() == 'HWC': if target_format == 'CHW': # We have a HWC image, but need CHW image = np.moveaxis(image, 2, 0) elif target_format == 'HWC': # Nothing to do here pass else: raise NotImplementedError(_not_implemented_message) else: raise RuntimeError return image @staticmethod def _normalize_image(image): normalized_image = image - image.min() maxval = normalized_image.max() if maxval > 0: normalized_image = normalized_image / maxval return normalized_image def log_histogram(self, tag, values, step, bins=1000): """Logs the histogram of a list/vector of values.""" # TODO raise NotImplementedError def get_config(self): # Apparently, some SwigPyObject objects cannot be pickled - so we need to build the # writer on the fly. 
config = super(TensorboardLogger, self).get_config() config.update({'_writer': None}) return config ================================================ FILE: inferno/trainers/callbacks/scheduling.py ================================================ from ...utils.train_utils import Frequency, Duration, MovingAverage from ...utils import python_utils as pyu from ...utils.exceptions import assert_, NotSetError from .base import Callback from functools import reduce class _Scheduler(Callback): def __init__(self, monitor='auto', monitor_momentum=0., monitor_while='auto'): super(_Scheduler, self).__init__() # Privates self._monitor_value_moving_average = MovingAverage(momentum=monitor_momentum) self._monitor_while = 'auto' self._monitor = 'auto' # Publics self.monitor = monitor self.monitor_while = monitor_while @property def monitor(self): assert_(self._monitor is not None, "Monitor is not set yet.", NotSetError) return self._monitor @monitor.setter def monitor(self, value): self._monitor = value @property def monitor_value(self): return self.get_monitor_value()[0] @property def monitor_while(self): if self._monitor_while == 'auto': monitor_value, monitor = self.get_monitor_value() if monitor.startswith('training_'): self._monitor_while = 'training' elif monitor.startswith('validation_'): self._monitor_while = 'validation' else: raise RuntimeError("Could not parse `monitor_while`. " "Please provide one manually.") return self._monitor_while @monitor_while.setter def monitor_while(self, value): value_mapping = {'auto': 'auto', 'training': 'training', 'validation': 'validation', 'validating': 'validation'} value = value_mapping.get(value) assert_(value is not None, "`monitor_while` must be one of {}, got {} instead." 
.format(value_mapping.keys(), value), ValueError) self._monitor_while = value def get_monitor_value(self): if self._monitor == 'auto': # Try to get validation error monitor_value = self.trainer.get_state('validation_error_averaged') if monitor_value is not None: self._monitor = 'validation_error_averaged' return monitor_value, self._monitor monitor_value = self.trainer.get_state('validation_loss_averaged') if monitor_value is not None: self._monitor = 'validation_loss_averaged' return monitor_value, self._monitor monitor_value = self.trainer.get_state('training_error') if monitor_value is not None: self._monitor = 'training_error' return monitor_value, self._monitor monitor_value = self.trainer.get_state('training_loss') if monitor_value is not None: self._monitor = 'training_loss' return monitor_value, self._monitor else: raise RuntimeError("Could not auto-fetch a monitor_value. " "Please specify a monitor manually.") else: monitor_value = self.trainer.get_state(self._monitor) assert_(monitor_value is not None, "Could not fetch the specified monitor ('{}') from trainer's state." .format(self._monitor), ValueError) return monitor_value, self._monitor def maintain_monitor_moving_average(self): monitor_value = self.monitor_value self._monitor_value_moving_average.update(monitor_value) return monitor_value class AutoLR(_Scheduler): """ Callback to decay or hike the learning rate automatically when a specified monitor stops improving. The monitor should be decreasing, i.e. lower value --> better performance. """ def __init__(self, factor, patience, required_minimum_relative_improvement=0, consider_improvement_with_respect_to='best', cooldown_duration=None, monitor='auto', monitor_momentum=0, monitor_while='auto', exclude_param_groups=None, verbose=False): """ Parameters ---------- factor : float Factor to multiply the learning rate with when out of patience and not in cooldown. 
Setting `factor < 1` results in a LR decay, whereas setting `factor > 1` results in a LR hike. patience : str or tuple or inferno.utils.train_utils.Duration Specifies how long to wait for an improvement before a LR decay is triggered. required_minimum_relative_improvement : float Specifies by how much (as a fraction of the current value) the monitor should improve to consider the improvement significant. Leaving this to zero implies the monitor will be considered improving even if it's only so slightly better. consider_improvement_with_respect_to : {'best', 'previous'} While determining if the monitor has improved, the improvement is considered with respect to this value. Could be 'best' or 'previous'. cooldown_duration: str or tuple or inferno.utils.train_utils.Duration Wait for this duration to resume operation after having decayed LR. monitor : str Specifies the monitor. Monitor must be a trainer state, and decrease with increasing performance. Examples: 'validation_error', 'training_loss'. The monitor can be 'auto' in which case it's recommended that you specify `monitor_while`. monitor_momentum : float A momentum to smooth the monitor history with. Usually recommended to smooth out any fluctuations in the monitor value. monitor_while : {'auto', 'training', 'validating'} Whether to monitor while training or validating. If the monitor is specified (i.e. is not 'auto'), this can be left to 'auto'. exclude_param_groups : int or list Parameter groups to __not__ apply the LR decay on. verbose : bool Specifies if a message be printed before decaying. 
""" super(AutoLR, self).__init__(monitor=monitor, monitor_momentum=monitor_momentum, monitor_while=monitor_while) # Validate assert_(consider_improvement_with_respect_to in ['best', 'previous'], "`consider_improvement_with_respect_to` must be either 'best' or 'previous', " "and not {}".format(consider_improvement_with_respect_to), ValueError) # Privates self._patience = None self._cooldown = None self._last_decayed_at = {'iteration_count': None, 'epoch_count': None} self._last_improved_at = {'iteration_count': None, 'epoch_count': None} self._best_monitor_value = None # Publics self.patience = patience self.cooldown_duration = cooldown_duration self.factor = factor self.required_minimum_relative_improvement = required_minimum_relative_improvement self.consider_improvement_with_respect_to = consider_improvement_with_respect_to self.exclude_param_groups = pyu.to_iterable(exclude_param_groups) \ if exclude_param_groups is not None else None self.verbose = verbose @property def patience(self): assert_(self._patience is not None, "Patience is not set yet.", NotSetError) return self._patience @patience.setter def patience(self, value): self._patience = Duration.build_from(value) @property def cooldown_duration(self): return self._cooldown @cooldown_duration.setter def cooldown_duration(self, value): if value is not None: self._cooldown = Duration.build_from(value) @property def duration_since_last_decay(self): since_last_decayed = {} if self._last_decayed_at.get('iteration_count') is None: since_last_decayed.update({'iteration_count': self.trainer.iteration_count}) else: since_last_decayed.update( {'iteration_count': (self.trainer.iteration_count - self._last_decayed_at['iteration_count']) }) if self._last_decayed_at.get('epoch_count') is None: since_last_decayed.update({'epoch_count': self.trainer.epoch_count}) else: since_last_decayed.update( {'epoch_count': (self.trainer.epoch_count - self._last_decayed_at['epoch_count']) }) return since_last_decayed @property def 
duration_since_last_improvment(self): since_last_improved = {} if self._last_improved_at.get('iteration_count') is None: since_last_improved.update({'iteration_count': self.trainer.iteration_count}) else: since_last_improved.update( {'iteration_count': (self.trainer.iteration_count - self._last_improved_at['iteration_count']) }) if self._last_improved_at.get('epoch_count') is None: since_last_improved.update({'epoch_count': self.trainer.epoch_count}) else: since_last_improved.update( {'epoch_count': (self.trainer.epoch_count - self._last_improved_at['epoch_count']) }) return since_last_improved @property def out_of_patience(self): return self.patience.match(**self.duration_since_last_improvment) @property def in_cooldown(self): if self.cooldown_duration is not None: return not self.cooldown_duration.match(**self.duration_since_last_decay) else: return False def decay(self): exclude_param_groups = \ [] if self.exclude_param_groups is None else list(self.exclude_param_groups) for param_group_num, param_group in enumerate(self.trainer.optimizer.param_groups): if param_group_num not in exclude_param_groups: param_group['lr'] *= self.factor self.debug_print("Decayed LR of param_group {} from {} --> {}" .format(param_group_num, param_group['lr'] / self.factor, param_group['lr'])) self._last_decayed_at.update({'iteration_count': self.trainer.iteration_count, 'epoch_count': self.trainer.epoch_count}) def maintain_monitor_moving_average(self): monitor_value = super(AutoLR, self).maintain_monitor_moving_average() if self._best_monitor_value is None: self._best_monitor_value = monitor_value @property def monitor_value_has_significantly_improved(self): if self._monitor_value_moving_average.previous is None: # There's nothing to compare with return True else: improvement_baseline = \ self._best_monitor_value \ if self.consider_improvement_with_respect_to == 'best' else \ self._monitor_value_moving_average.previous monitor_value_has_significantly_improved = \ 
self.is_significantly_less_than(self._monitor_value_moving_average.val, improvement_baseline, self.required_minimum_relative_improvement) self.debug_print("Is {} significantly less than {} with min_relative_delta = {}? {}." .format(self._monitor_value_moving_average.val, improvement_baseline, self.required_minimum_relative_improvement, monitor_value_has_significantly_improved)) # monitor_value_has_significantly_improved could be False, even if the current # moving average is less than the best monitor value, if the improvement is not # significant enough self._best_monitor_value = min([self._best_monitor_value, self._monitor_value_moving_average.val]) if monitor_value_has_significantly_improved: self._last_improved_at.update({'iteration_count': self.trainer.iteration_count, 'epoch_count': self.trainer.epoch_count}) return monitor_value_has_significantly_improved def end_of_training_iteration(self, **_): # Decay if we're not in cooldown (and monitoring while training) if self.monitor_while == 'training': self.maintain_monitor_moving_average() if not self.monitor_value_has_significantly_improved and \ self.out_of_patience and not self.in_cooldown: if self.verbose: self.trainer.console.info("Monitor '{}' has not significantly improved, decaying LR." .format(self.monitor)) self.decay() def end_of_validation_run(self, **_): if self.monitor_while == 'validation': self.maintain_monitor_moving_average() if not self.monitor_value_has_significantly_improved \ and self.out_of_patience and not self.in_cooldown: if self.verbose: self.trainer.console.info("Monitor '{}' has not significantly improved " "({} vs. {}), decaying LR." 
.format(self.monitor, self._monitor_value_moving_average.val, self._best_monitor_value)) self.decay() @staticmethod def is_significantly_less_than(x, y, min_relative_delta): eps = 1.e-6 if x > y: return False relative_delta = abs(y - x) / (abs(y) + eps) return relative_delta > min_relative_delta class AutoLRDecay(AutoLR): """ Callback to decay the learning rate automatically when a specified monitor stops improving. The monitor should be decreasing, i.e. lower value --> better performance. """ pass class DecaySpec(object): """A class to specify when to decay (or hike) LR and by what factor.""" def __init__(self, duration, factor): # Privates self._matched = False # Publics self.duration = Duration.build_from(duration) self.factor = factor def match(self, iteration_count=None, epoch_count=None, when_equal_return=True): match_result = self.duration.match(iteration_count=iteration_count, epoch_count=epoch_count, when_equal_return=when_equal_return) if match_result and not self._matched: # First match self._matched = True return match_result else: # Already matched once (or more often) return False def new(self): return type(self)(self.duration, self.factor) @classmethod def build_from(cls, args): if isinstance(args, (list, tuple)): return cls(*args) elif isinstance(args, dict): return cls(**args) elif isinstance(args, cls): return args else: raise NotImplementedError("Can't build DecaySpec from {}.".format(type(args))) class ManualLR(Callback): def __init__(self, decay_specs, exclude_param_groups=None): super(ManualLR, self).__init__() self.decay_specs = [DecaySpec.build_from(decay_spec) for decay_spec in pyu.to_iterable(decay_specs)] self.exclude_param_groups = pyu.to_iterable(exclude_param_groups) \ if exclude_param_groups is not None else None def match(self): # Find the decayspec that matched matched = [decay_spec for decay_spec in self.decay_specs if decay_spec.match(iteration_count=self.trainer.iteration_count, epoch_count=self.trainer.epoch_count)] if matched: 
# Allow for more than one matches; in which case the factors are multiplied global_factor = reduce(lambda x, y: x * y, [matched_decay_spec.factor for matched_decay_spec in matched]) return True, global_factor else: return False, None def decay(self, factor): exclude_param_groups = \ [] if self.exclude_param_groups is None else list(self.exclude_param_groups) for param_group_num, param_group in enumerate(self.trainer.optimizer.param_groups): if param_group_num not in exclude_param_groups: param_group['lr'] *= factor self.debug_print("Decayed LR of param_group {} from {} --> {}" .format(param_group_num, param_group['lr'] / factor, param_group['lr'])) def end_of_training_iteration(self, **_): matched, global_factor = self.match() if matched: assert global_factor is not None self.decay(global_factor) class SaveModelRegularly(Callback): """saves the network weights in regular intervals""" def __init__(self, frequency): super().__init__() self._save_every = Frequency.build_from(frequency) @property def save_now(self): return self._save_every.match(iteration_count=self.trainer.iteration_count, epoch_count=self.trainer.epoch_count, persistent=True, match_zero=True) def end_of_training_iteration(self, **_): if self.save_now: self.trainer.save_model() ================================================ FILE: inferno/trainers/callbacks/tqdm.py ================================================ from .base import Callback from tqdm import tqdm from datetime import datetime from .console import Console class TQDMPrinter(object): def __init__(self, progress): self._progress = progress def print(self, message): if self._progress.outer_bar is not None: self._progress.outer_bar.clear() tqdm.write(message) if self._progress.outer_bar is not None: self._progress.outer_bar.refresh() class TQDMConsole(Console): def __init__(self): super(TQDMConsole, self).__init__(printer=TQDMPrinter(TQDMProgressBar())) class TQDMProgressBar(Callback): def __init__(self, *args, **kwargs): 
super(TQDMProgressBar, self).__init__(*args, **kwargs) self.epoch_bar = None self.outer_bar = None self.is_training = False self.is_validation = False def bind_trainer(self, *args, **kwargs): super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) self.trainer.console.toggle_progress(False) self.trainer.console.set_console(TQDMPrinter(self)) def _init_epoch_bar_train(self): n_batch = len(self.trainer._loader_iters['train']) self.epoch_bar = tqdm(total=n_batch, position=1, dynamic_ncols=True) self.epoch_bar.update(self.trainer._batch_count) self.epoch_bar.set_description("Training epoch %d" % self.trainer._epoch_count) def print(self, message, **_): if self.outer_bar is not None: self.outer_bar.clear() tqdm.write("[+][{}] {}".format(str(datetime.now()), message)) if self.outer_bar is not None: self.outer_bar.refresh() def begin_of_fit(self, max_num_epochs, **_): if isinstance(max_num_epochs, int): self.outer_bar = tqdm(total=max_num_epochs, position=0, dynamic_ncols=True) else: self.outer_bar = tqdm(total=1000, position=0, dynamic_ncols=True) self.outer_bar.set_description("Epochs") def end_of_fit(self, **_): if self.outer_bar is not None: self.outer_bar.close() self.outer_bar = None def begin_of_epoch(self, **_): if self.epoch_bar is not None: self.epoch_bar.close() def end_of_epoch(self, **_): if self.outer_bar is not None: self.outer_bar.update(1) def begin_of_training_iteration(self, **_): if not self.epoch_bar and 'train' in self.trainer._loader_iters.keys(): self._init_epoch_bar_train() return if self.epoch_bar: self.epoch_bar.update(1) def begin_of_validation_iteration(self, **_): if self.epoch_bar: self.epoch_bar.update(1) def begin_of_training_run(self, **_): self.is_training = True def end_of_training_run(self, **_): self.is_training = False if self.epoch_bar: self.epoch_bar.close() self.epoch_bar = None def begin_of_validation_run(self, num_iterations, num_iterations_in_generator, last_validated_at_epoch, **_): self.is_validation = True nmax = 
num_iterations if not nmax: nmax = num_iterations_in_generator self.epoch_bar = tqdm(total=nmax, position=1, dynamic_ncols=True) self.epoch_bar.set_description("Validating epoch %d" % (last_validated_at_epoch-1)) def end_of_validation_run(self, **_): self.is_validation = False if self.epoch_bar: self.epoch_bar.close() self.epoch_bar = None ================================================ FILE: inferno/trainers/callbacks/tqdmstub.py ================================================ from .base import Callback class TQDMProgressBar(Callback): def __init__(self, *args, **kwargs): super(TQDMProgressBar, self).__init__(*args, **kwargs) def bind_trainer(self, *args, **kwargs): super(TQDMProgressBar, self).bind_trainer(*args, **kwargs) self.trainer.console.warning("tqdm is not installed. will fall back to normal stdout console.") def begin_of_fit(self, **_): pass ================================================ FILE: inferno/utils/__init__.py ================================================ ================================================ FILE: inferno/utils/exceptions.py ================================================ """Exceptions and Error Handling""" def assert_(condition, message='', exception_type=AssertionError): """Like assert, but with arbitrary exception types.""" if not condition: raise exception_type(message) # ------ VALUE ERRORS ------ class ShapeError(ValueError): pass class FrequencyValueError(ValueError): pass class DeviceError(ValueError): pass class NotSetError(ValueError): pass # ------ TYPE ERRORS ------ class NotTorchModuleError(TypeError): pass class FrequencyTypeError(TypeError): pass class DTypeError(TypeError): pass # ------ LOOKUP ERRORS ------ class ClassNotFoundError(LookupError): pass # ------ NOT-IMPLEMENTED ERRORS ------ class NotUnwrappableError(NotImplementedError): pass ================================================ FILE: inferno/utils/io_utils.py ================================================ import os import h5py as h5 import numpy 
as np import yaml from skimage.io import imsave # Function to load in a dataset from a h5file def fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=None): """ Opens a hdf5 file at path, loads in the dataset at datapath, and returns dataset as a numpy array. """ # Check if path exists (thanks Lukas!) assert os.path.exists(path), "Path {} does not exist.".format(path) with h5.File(path, 'r') as f: # Init dataset h5dataset = f[datapath] if datapath is not None else f.values()[0] # Slice dataset h5dataset = h5dataset[dataslice] if dataslice is not None else h5dataset # Convert to numpy if required h5dataset = np.asarray(h5dataset) if asnumpy else h5dataset # Apply preptrain h5dataset = preptrain(h5dataset) if preptrain is not None else h5dataset return h5dataset # TODO we could also do **h5_kwargs instead def toh5(data, path, datapath='data', compression=None, chunks=None): """Write `data` to a HDF5 volume.""" with h5.File(path) as f: f.create_dataset(datapath, data=data, compression=compression, chunks=chunks) def fromz5(path, datapath, dataslice=None, n_threads=8): # we import z5py only here because we don't want to assume that it's in the env import z5py assert os.path.exists(path), "Path {} does not exist.".format(path) with z5py.File(path) as f: ds = f[datapath] ds.n_threads = n_threads data = ds[:] if dataslice is None else ds[dataslice] return data # Yaml to dict reader def yaml2dict(path): if isinstance(path, dict): # Forgivable mistake that path is a dict already return path with open(path, 'r') as f: readict = yaml.load(f, Loader=yaml.FullLoader) return readict def print_tensor(tensor, prefix, directory): """Prints a image or volume tensor to file as images.""" def _print_image(image, prefix, batch, channel, z=None): if z is None: file_name = "{}--B-{}--CH-{}.png".format(prefix, batch, channel) else: file_name = "{}--B-{}--CH-{}--Z-{}.png".format(prefix, batch, channel, z) full_file_name = os.path.join(directory, file_name) imsave(arr=image, 
fname=full_file_name) for batch in range(tensor.shape[0]): for channel in range(tensor.shape[1]): if tensor.ndim == 4: _print_image(tensor[batch, channel, ...], prefix, batch, channel) else: for plane in range(tensor.shape[2]): _print_image(tensor[batch, channel, plane, ...], prefix, batch, channel, plane) ================================================ FILE: inferno/utils/math_utils.py ================================================ def max_allowed_ds_steps(shape, factor): """How often can a shape be down-sampled by a given factor such that non of the divisions will give non-integers. Args: shape (listlike): tensor shape factor (integer): downsample factor Returns: int: maximum allowed downsample operations """ def max_allowed_ds_steps_impl(size, factor): current_size = float(size) allowed_steps = 0 while(True): new_size = current_size / float(factor) if(new_size >=1 and new_size.is_integer()): current_size = new_size allowed_steps += 1 else: break return allowed_steps min_steps = float('inf') for s in shape: min_steps = int(min(min_steps, max_allowed_ds_steps_impl(s, factor))) return min_steps ================================================ FILE: inferno/utils/model_utils.py ================================================ import torch from .exceptions import assert_, NotTorchModuleError, ShapeError def is_model_cuda(model): try: return next(model.parameters()).is_cuda except StopIteration: # Assuming that if a network has no parameters, it doesn't use CUDA return False class ModelTester(object): def __init__(self, input_shape, expected_output_shape): self._is_cuda = False self.input_shape = input_shape self.expected_output_shape = expected_output_shape def cuda(self): self._is_cuda = True return self def get_input(self): with torch.no_grad(): if self._is_cuda: return torch.rand(*self.input_shape, requires_grad=False).cuda() else: return torch.rand(*self.input_shape, requires_grad=False) def __call__(self, model): # Make sure model is a model 
assert_(isinstance(model, torch.nn.Module), "Model is not a torch module.", NotTorchModuleError) # Transfer to cuda if required if not is_model_cuda(model) and self._is_cuda: model.cuda() input_ = self.get_input() output = model(input_) assert_(list(output.size()) == list(self.expected_output_shape), "Expected output shape {} for input shape {}, " "got output of shape {} instead.".format(list(self.expected_output_shape), list(self.input_shape), list(output.size())), ShapeError) return model class MultiscaleModelTester(ModelTester): def __call__(self, model): # Make sure model is a model assert_(isinstance(model, torch.nn.Module), "Model is not a torch module.", NotTorchModuleError) # Transfer to cuda if required if not is_model_cuda(model) and self._is_cuda: model.cuda() input_ = self.get_input() output = model(input_) assert_(isinstance(output, tuple), "Expect tuple output") for scale in range(len(output)): assert_(list(output[scale].size()) == list(self.expected_output_shape[scale]), "Expected output shape {} for input shape {}, " "got output of shape {} instead.".format(list(self.expected_output_shape[scale]), list(self.input_shape), list(output[scale].size())), ShapeError) return model ================================================ FILE: inferno/utils/partial_cls.py ================================================ import functools import sys import types import inspect __all__ = [ 'partial_cls', 'register_partial_cls' ] def partial_cls(base_cls, name, module, fix=None, default=None): # helper function def insert_if_not_present(dict_a, dict_b): for kw,val in dict_b.items(): if kw not in dict_a: dict_a[kw] = val return dict_a # helper function def insert_call_if_present(dict_a, dict_b, callback): for kw,val in dict_b.items(): if kw not in dict_a: dict_a[kw] = val else: callback(kw) return dict_a # helper class class PartialCls(object): def __init__(self, base_cls, name, module, fix=None, default=None): self.base_cls = base_cls self.name = name self.module = 
module self.fix = [fix, {}][fix is None] self.default = [default, {}][default is None] if self.fix.keys() & self.default.keys(): raise TypeError('fix and default share keys') # remove binded kw self._allowed_kw = self._get_allowed_kw() def _get_allowed_kw(self): argspec = inspect.getfullargspec(base_cls.__init__) args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = argspec if varargs is not None: raise TypeError('partial_cls can only be used if __init__ has no varargs') if varkw is not None: raise TypeError('partial_cls can only be used if __init__ has no varkw') if kwonlyargs is not None and kwonlyargs != []: raise TypeError('partial_cls can only be used without kwonlyargs') if args is None or len(args) < 1: raise TypeError('seems like self is missing') return [kw for kw in args[1:] if kw not in self.fix] def _build_kw(self, args, kwargs): # handle *args if len(args) > len(self._allowed_kw): raise TypeError("to many arguments") all_args = {} for arg, akw in zip(args, self._allowed_kw): all_args[akw] = arg # handle **kwargs intersection = self.fix.keys() & kwargs.keys() if len(intersection) >= 1: kw = intersection.pop() raise TypeError("`{}.__init__` got unexpected keyword argument '{}'".format(name, kw)) def raise_cb(kw): raise TypeError("{}.__init__ got multiple values for argument '{}'".format(name, kw)) all_args = insert_call_if_present(all_args, kwargs, raise_cb) # handle fixed arguments def raise_cb(kw): raise TypeError() all_args = insert_call_if_present(all_args, self.fix, raise_cb) # handle defaults all_args = insert_if_not_present(all_args, self.default) # handle fixed all_args.update(self.fix) return all_args def build_cls(self): def new_init(self_of_new_cls, *args, **kwargs): combined_args = self._build_kw(args=args, kwargs=kwargs) #call base cls init super(self_of_new_cls.__class__, self_of_new_cls).__init__(**combined_args) return type(name, (self.base_cls,), { '__module__': self.module, '__init__' : new_init }) return cls return 
PartialCls(base_cls=base_cls, name=name, module=module, fix=fix, default=default).build_cls() def register_partial_cls(base_cls, name, module, fix=None, default=None): module_dict = sys.modules[module].__dict__ generatedClass = partial_cls(base_cls=base_cls,name=name, module=module, fix=fix, default=default) module_dict[generatedClass.__name__] = generatedClass del generatedClass if __name__ == "__main__": class Conv(object): def __init__(self, dim, activation, stride=1): print(f"dim {dim} act {activation} stride {stride}") Conv2D = partial_cls(Conv,'Conv2D',__name__, fix=dict(dim=2), default=dict(stride=2)) #obj = Conv2D(activation='a') #obj = Conv2D('a',activation='a', stride=3) obj = Conv2D('fu','bar') ================================================ FILE: inferno/utils/python_utils.py ================================================ """Utility functions with no external dependencies.""" import signal import warnings import functools import inspect import os from threading import current_thread, main_thread def ensure_dir(directory): """ensure the existence of e directory at a given path If the directory does not exist it is created Args: directory (str): path of the directory Returns: str: path of the directory """ if not os.path.exists(directory): os.makedirs(directory) return directory def require_dict_kwargs(kwargs, msg=None): """ Ensure arguments passed kwargs are either None or a dict. 
If arguments are neither a dict nor None a RuntimeError is thrown Args: kwargs (object): possible dict or None msg (None, optional): Error msg Returns: dict: kwargs dict Raises: RuntimeError: if the passed value is neither a dict nor None this error is raised """ if kwargs is None: return dict() elif isinstance(kwargs, dict): return kwargs else: if msg is None: raise RuntimeError("value passed as keyword argument dict is neither None nor a dict") else: raise RuntimeError("%s"%str(msg)) def is_listlike(x): return isinstance(x, (list, tuple)) def to_iterable(x): return [x] if not is_listlike(x) else x def from_iterable(x): return x[0] if (is_listlike(x) and len(x) == 1) else x def robust_len(x): return len(x) if is_listlike(x) else 1 def as_tuple_of_len(x, len_): if is_listlike(x): assert len(x) == len_, \ "Listlike object of len {} can't be returned " \ "as a tuple of length {}.".format(len(x), len_) return tuple(x) else: return (x,) * len_ def has_callable_attr(object_, name): return hasattr(object_, name) and callable(getattr(object_, name)) def is_maybe_list_of(check_function): def decorated_function(object_, **kwargs): if isinstance(object_, (list, tuple)): return all([check_function(_object, **kwargs) for _object in object_]) else: return check_function(object_, **kwargs) return decorated_function class delayed_keyboard_interrupt(object): """ Delays SIGINT over critical code. 
Borrowed from: https://stackoverflow.com/questions/842557/ how-to-prevent-a-block-of-code-from-being-interrupted-by-keyboardinterrupt-in-py """ # PEP8: Context manager class in lowercase def __enter__(self): if current_thread() is main_thread(): self.signal_received = False self.old_handler = signal.getsignal(signal.SIGINT) signal.signal(signal.SIGINT, self.handler) def handler(self, sig, frame): self.signal_received = (sig, frame) def __exit__(self, type, value, traceback): if current_thread() is main_thread(): signal.signal(signal.SIGINT, self.old_handler) if self.signal_received: self.old_handler(*self.signal_received) def get_config_for_name(config, name): config_for_name = {} for key, val in config.items(): if isinstance(val, dict) and name in val: # we leave the slicing_config validation to classes higher up in MRO config_for_name.update({key: val.get(name)}) else: config_for_name.update({key: val}) return config_for_name string_types = (type(b''), type(u'')) def deprecated(reason): """ This is a decorator which can be used to mark functions as deprecated. It will result in a warning being emitted when the function is used. Borrowed from https://stackoverflow.com/questions/2536307/ decorators-in-the-python-standard-lib-deprecated-specifically by Laurent LAPORTE https://stackoverflow.com/users/1513933/laurent-laporte """ if isinstance(reason, string_types): # The @deprecated is used with a 'reason'. # # .. code-block:: python # # @deprecated("please, use another function") # def old_function(x, y): # pass def decorator(func1): if inspect.isclass(func1): fmt1 = "Call to deprecated class {name} ({reason})." else: fmt1 = "Call to deprecated function {name} ({reason})." 
@functools.wraps(func1) def new_func1(*args, **kwargs): warnings.simplefilter('always', DeprecationWarning) warnings.warn( fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2 ) warnings.simplefilter('default', DeprecationWarning) return func1(*args, **kwargs) return new_func1 return decorator elif inspect.isclass(reason) or inspect.isfunction(reason): # The @deprecated is used without any 'reason'. # # .. code-block:: python # # @deprecated # def old_function(x, y): # pass func2 = reason if inspect.isclass(func2): fmt2 = "Call to deprecated class {name}." else: fmt2 = "Call to deprecated function {name}." @functools.wraps(func2) def new_func2(*args, **kwargs): warnings.simplefilter('always', DeprecationWarning) warnings.warn( fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2 ) warnings.simplefilter('default', DeprecationWarning) return func2(*args, **kwargs) return new_func2 else: raise TypeError(repr(type(reason))) ================================================ FILE: inferno/utils/test_utils.py ================================================ import torch from torch.utils.data.dataset import TensorDataset from torch.utils.data.dataloader import DataLoader import numpy as np def generate_random_data(num_samples, shape, num_classes, hardness=0.3, dtype=None): """Generate a random dataset with a given hardness and number of classes.""" dataset_input = np.zeros((num_samples,) + shape, dtype=dtype) dataset_target = np.random.randint(num_classes, size=num_samples) for sample_num in range(num_samples): dataset_input[sample_num] = np.random.normal(loc=dataset_target[sample_num], scale=(1 - hardness), size=shape) return dataset_input, dataset_target def generate_random_dataset(num_samples, shape, num_classes, hardness=0.3, dtype=None): """Generate a random dataset with a given hardness and number of classes.""" # Generate numpy arrays dataset_input, dataset_target = generate_random_data(num_samples, shape, 
num_classes, hardness=hardness, dtype=dtype) # Convert to tensor and build dataset dataset = TensorDataset(torch.from_numpy(dataset_input), torch.from_numpy(dataset_target)) return dataset def generate_random_dataloader(num_samples, shape, num_classes, hardness=0.3, dtype=None, batch_size=1, shuffle=False, num_workers=0, pin_memory=False): """Generate a loader with a random dataset of given hardness and number of classes.""" dataset = generate_random_dataset(num_samples, shape, num_classes, hardness=hardness, dtype=dtype) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory) return dataloader ================================================ FILE: inferno/utils/torch_utils.py ================================================ import numpy as np import torch from .python_utils import delayed_keyboard_interrupt from .exceptions import assert_, ShapeError, NotUnwrappableError def unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False): if isinstance(input_, (list, tuple)): return type(input_)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy) for _t in input_]) elif torch.is_tensor(input_): tensor = input_ elif isinstance(input_, np.ndarray): return input_ elif isinstance(input_, (float, int)): return input_ else: raise NotUnwrappableError("Cannot unwrap a '{}'." 
.format(type(input_).__name__)) # Transfer to CPU if required if to_cpu: with delayed_keyboard_interrupt(): tensor = tensor.cpu().detach() # Convert to numpy if required if as_numpy: return tensor.cpu().detach().numpy() elif extract_item: try: return tensor.item() except AttributeError: return tensor[0] else: return tensor def is_tensor(object_): missed_tensor_classes = (torch.HalfTensor,) return torch.is_tensor(object_) or isinstance(object_, missed_tensor_classes) def is_label_tensor(object_): return is_tensor(object_) and object_.type() in ['torch.LongTensor', 'torch.cuda.LongTensor'] def is_image_tensor(object_): return is_tensor(object_) and object_.dim() == 4 def is_volume_tensor(object_): return is_tensor(object_) and object_.dim() == 5 def is_image_or_volume_tensor(object_): return is_image_tensor(object_) or is_volume_tensor(object_) def is_label_image_tensor(object_): return is_label_tensor(object_) and object_.dim() == 3 def is_label_volume_tensor(object_): return is_label_tensor(object_) and object_.dim() == 4 def is_label_image_or_volume_tensor(object_): return is_label_image_tensor(object_) or is_label_volume_tensor(object_) def is_matrix_tensor(object_): return is_tensor(object_) and object_.dim() == 2 def is_scalar_tensor(object_): return is_tensor(object_) and object_.dim() <= 1 and object_.numel() == 1 def is_vector_tensor(object_): return is_tensor(object_) and object_.dim() == 1 and object_.numel() > 1 def assert_same_size(tensor_1, tensor_2): assert_(list(tensor_1.size()) == list(tensor_2.size()), "Tensor sizes {} and {} do not match.".format(tensor_1.size(), tensor_2.size()), ShapeError) def where(condition, if_true, if_false): """ Torch equivalent of numpy.where. Parameters ---------- condition : torch.ByteTensor or torch.cuda.ByteTensor Condition to check. if_true : torch.Tensor or torch.cuda.Tensor Output value if condition is true. 
if_false: torch.Tensor or torch.cuda.Tensor Output value if condition is false Returns ------- torch.Tensor Raises ------ AssertionError if if_true and if_false don't have the same datatype. """ # noinspection PyArgumentList assert if_true.type() == if_false.type(), \ "Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type()) casted_condition = condition.type_as(if_true) output = casted_condition * if_true + (1 - casted_condition) * if_false return output def flatten_samples(input_): """ Flattens a tensor or a variable such that the channel axis is first and the sample axis is second. The shapes are transformed as follows: (N, C, H, W) --> (C, N * H * W) (N, C, D, H, W) --> (C, N * D * H * W) (N, C) --> (C, N) The input must be atleast 2d. """ assert_(input_.dim() >= 2, "Tensor or variable must be atleast 2D. Got one of dim {}." .format(input_.dim()), ShapeError) # Get number of channels num_channels = input_.size(1) # Permute the channel axis to first permute_axes = list(range(input_.dim())) permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0] # For input shape (say) NCHW, this should have the shape CNHW permuted = input_.permute(*permute_axes).contiguous() # Now flatten out all but the first axis and return flattened = permuted.view(num_channels, -1) return flattened def clip_gradients_(parameters, mode, norm_or_value): assert_(mode in ['norm', 'value'], f"Mode must be 'norm' or 'value', got '{mode}' instead.", ValueError) if mode == 'norm': torch.nn.utils.clip_grad_norm_(parameters, norm_or_value) elif mode == 'value': torch.nn.utils.clip_grad_value_(parameters, norm_or_value) else: raise NotImplementedError ================================================ FILE: inferno/utils/train_utils.py ================================================ """Utilities for training.""" import numpy as np from .exceptions import assert_, FrequencyTypeError, FrequencyValueError class AverageMeter(object): """ Computes and stores the average and 
    current value.
    Taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    def __init__(self):
        # Running statistics; `avg` is the mean of all updates weighted by `n`.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        # Restore the meter to its freshly-constructed state.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `n` is the number of samples `val` was averaged over (e.g. batch size).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MovingAverage(object):
    """Computes the moving average of a given float."""
    def __init__(self, momentum=0):
        # momentum == 0 degenerates to "last value wins" (no smoothing).
        self.momentum = momentum
        self.val = None
        self.previous = None

    def reset(self):
        # NOTE(review): `previous` is not cleared here, so `relative_change`
        # can still be computed across a reset — confirm this is intended.
        self.val = None

    def update(self, val):
        self.previous = self.val
        if self.val is None:
            # First observation seeds the average.
            self.val = val
        else:
            # Exponential moving average update.
            self.val = self.momentum * self.val + (1 - self.momentum) * val
        return self.val

    @property
    def relative_change(self):
        # Signed relative decrease w.r.t. the previous average; None until at
        # least two updates have been seen.
        # NOTE(review): divides by `previous` — raises if it is exactly zero.
        if None not in [self.val, self.previous]:
            relative_change = (self.previous - self.val) / self.previous
            return relative_change
        else:
            return None


class CLUI(object):
    """Command Line User Interface"""
    def __call__(self, f):
        # Decorator: wraps `f` so a KeyboardInterrupt during the call drops
        # into an interactive save/debug/quit/continue prompt.
        def decorated(cls, *args, **kwargs):
            try:
                f(cls, *args, **kwargs)
            except KeyboardInterrupt:
                options_ = input("[!] Interrupted. Please select:\n"
                                 "[w] Save\n"
                                 "[d] Debug with PDB\n"
                                 "[q] Quit\n"
                                 "[c] Continue\n"
                                 "[?] >>> ")
                # Multiple options may be combined, e.g. 'wd' saves then debugs.
                save_now = 'w' in options_
                quit_now = 'q' in options_
                debug_now = 'd' in options_
                # Continue by default unless the user explicitly quits.
                continue_now = 'c' in options_ or not quit_now
                if save_now:
                    cls.save()
                if debug_now:
                    print("[*] Firing up PDB. 
The trainer instance might be accessible as 'cls'.") import pdb pdb.set_trace() if quit_now: cls.print("Exiting.") raise SystemExit if continue_now: return return decorated class Frequency(object): def __init__(self, value=None, units=None): # Private self._last_match_value = None self._value = None self._units = None # Public self.value = value self.units = units @property def value(self): return self._value @value.setter def value(self, value): # If value is not being set, make sure the frequency never matches muhahaha if value is None or value == 'never': value = np.inf self.assert_value_consistent(value) self._value = value UNIT_PRIORITY = 'iterations' VALID_UNIT_NAME_MAPPING = {'iterations': 'iterations', 'iteration': 'iterations', 'epochs': 'epochs', 'epoch': 'epochs'} @property def units(self): return self._units @units.setter def units(self, value): if value is None: value = self.UNIT_PRIORITY self.assert_units_consistent(value) self._units = self.VALID_UNIT_NAME_MAPPING.get(value) def assert_value_consistent(self, value=None): value = value or self.value # Make sure that value is an integer or inf assert_(isinstance(value, (int, float)), "Value must be an integer or np.inf, got {} instead." 
                .format(type(value).__name__), FrequencyTypeError)
        if isinstance(value, float):
            # Floats are only allowed as the np.inf sentinel meaning "never".
            assert_(value == np.inf,
                    "Provided value must be numpy.inf if a float, got {}.".format(value),
                    FrequencyValueError)

    def assert_units_consistent(self, units=None):
        """Raise FrequencyValueError if `units` (or self.units) is not a known unit name."""
        # Falls back to self.units when `units` is falsy (None or empty string).
        units = units or self.units
        # Map
        # NOTE(review): after a failed mapping `units` is None, so the error
        # message prints None rather than the offending name — confirm.
        units = self.VALID_UNIT_NAME_MAPPING.get(units)
        assert_(units is not None, "Unit '{}' not understood.".format(units), FrequencyValueError)

    @property
    def is_consistent(self):
        # True iff both value and units pass their consistency asserts.
        try:
            self.assert_value_consistent()
            self.assert_units_consistent()
            return True
        except (FrequencyValueError, FrequencyTypeError):
            return False

    def epoch(self):
        # Fluent setter: e.g. Frequency().epoch().every(2)
        self.units = 'epochs'
        return self

    def iteration(self):
        # Fluent setter: e.g. Frequency().iteration().every(100)
        self.units = 'iterations'
        return self

    @property
    def by_epoch(self):
        return self.units == 'epochs'

    @property
    def by_iteration(self):
        return self.units == 'iterations'

    def every(self, value):
        # Fluent setter for the frequency value.
        self.value = value
        return self

    def match(self, iteration_count=None, epoch_count=None, persistent=False, match_zero=True):
        """Return True if the counter for our units is a multiple of self.value."""
        # Pick the counter corresponding to our units.
        match_value = {'iterations': iteration_count,
                       'epochs': epoch_count}.get(self.units)
        if not match_zero and match_value == 0:
            match = False
        else:
            # np.inf means "never match"; otherwise match every `value` counts.
            match = match_value is not None and \
                    self.value != np.inf and \
                    match_value % self.value == 0
        if persistent and match and self._last_match_value == match_value:
            # Last matched value is the current matched value, i.e. we've matched once already,
            # and don't need to match again
            match = False
        if match:
            # Record current match value as the last known match value to maintain persistency
            self._last_match_value = match_value
        return match

    def __str__(self):
        return "{} {}".format(self.value, self.units)

    def __repr__(self):
        return "{}(value={}, units={})".format(type(self).__name__, self.value, self.units)

    @classmethod
    def from_string(cls, string):
        """Parse a Frequency from strings like '10 iterations', '1 epoch' or 'never'."""
        assert_(isinstance(string, str),
                "`string` must be a string, got {} instead."
                .format(type(string).__name__), TypeError)
        if string == 'never':
            return cls(np.inf, 'iterations')
        else:
            value_and_unit = string.split(' ')
            assert_(len(value_and_unit) == 2,
                    "Was expecting a string 'value units' with one white-space "
                    "between 'value' and 'units'.", ValueError)
            value, unit = value_and_unit
            value = np.inf if value == 'inf' else int(value)
            return cls(value, unit)

    @classmethod
    def build_from(cls, args, priority='iterations'):
        """Build a Frequency from an int, a (value, units) pair, a Frequency, or a string."""
        if isinstance(args, int):
            # Bare int: `priority` supplies the units.
            return cls(args, priority)
        elif isinstance(args, (tuple, list)):
            return cls(*args)
        elif isinstance(args, Frequency):
            # Already a Frequency (or Duration): pass through unchanged.
            return args
        elif isinstance(args, str):
            return cls.from_string(args)
        else:
            raise NotImplementedError


class Duration(Frequency):
    """Like frequency, but measures a duration."""
    def match(self, iteration_count=None, epoch_count=None, when_equal_return=False, **_):
        # A duration "matches" once the relevant counter has passed self.value;
        # at exact equality, `when_equal_return` decides the result.
        match_value = {'iterations': iteration_count,
                       'epochs': epoch_count}.get(self.units)
        assert_(match_value is not None,
                "Could not match duration because {} is not known.".format(self.units),
                ValueError)
        if match_value == self.value:
            return when_equal_return
        return match_value > self.value

    def compare(self, iteration_count=None, epoch_count=None):
        """Return a dict mapping units to remaining counts (None for the other unit)."""
        compare_value = {'iterations': iteration_count,
                         'epochs': epoch_count}.get(self.units)
        assert_(compare_value is not None,
                "Could not match duration because {} is not known.".format(self.units),
                ValueError)
        compared = {'iterations': None, 'epochs': None}
        compared.update({self.units: self.value - compare_value})
        return compared

    def __sub__(self, other):
        # Durations subtract only against same-unit Durations.
        assert_(isinstance(other, Duration),
                "Object of type {} cannot be subtracted from "
                "a Duration object.".format(type(other)), TypeError)
        assert_(other.units == self.units,
                "The Duration objects being subtracted must have the same units.",
                ValueError)
        return Duration(value=(self.value - other.value), units=self.units)


class NoLogger(object):
    # No-op stand-in matching the logger interface; discards all values.
    def __init__(self, logdir=None):
        self.logdir = logdir

    def log_value(self, *kwargs):
        # NOTE(review): the varargs parameter is conventionally named *args;
        # renaming would not change behavior but is out of scope for a
        # doc-only pass.
        pass


def
set_state(module, key, value): """Writes `key`-`value` pair to `module`'s state hook.""" if hasattr(module, '_state_hooks'): state_hooks = getattr(module, '_state_hooks') assert isinstance(state_hooks, dict), \ "State hook (i.e. module._state_hooks) is not a dictionary." state_hooks.update({key: value}) else: setattr(module, '_state_hooks', {key: value}) return module def get_state(module, key, default=None): """Gets key from `module`'s state hooks.""" return getattr(module, '_state_hooks', {}).get(key, default) ================================================ FILE: inferno/version.py ================================================ __version__ = '0.4.0' ================================================ FILE: readthedocs.yml ================================================ conda: file: docs/environment.yml python: version: 3.5 pip_install: false ================================================ FILE: requirements.txt ================================================ dill pyyaml scipy>=0.13.0 h5py numpy>=1.8 scikit-image ================================================ FILE: requirements_dev.txt ================================================ pip==8.1.2 bumpversion==0.5.3 wheel==0.29.0 watchdog==0.8.3 flake8==2.6.0 tox==2.3.1 coverage==4.1 Sphinx==1.4.8 cryptography==1.7 PyYAML==5.1 dill pyyaml scipy>=0.13.0 h5py scikit-image sphinx-gallery sphinxcontrib-napoleon sphinxcontrib-inlinesyntaxhighlight sphinx_rtd_theme ================================================ FILE: setup.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- """The setup script.""" from setuptools import setup, find_packages import runpy __version__ = runpy.run_path('inferno/version.py')['__version__'] with open('README.rst') as readme_file: readme = readme_file.read() with open('HISTORY.rst') as history_file: history = history_file.read() requirements = [ # TODO: put package requirements here "pip>=8.1.2", "torch>=0.1.12", "dill", "pyyaml", 
"scipy>=0.13.0", "h5py", "numpy>=1.8", "scikit-image", "torchvision", "tqdm" ] setup_requirements = [ 'pytest-runner' ] test_requirements = [ 'pytest', 'unittest' ] dependency_links = [ 'http://download.pytorch.org/whl/cu75/torch-0.2.0.post1-cp35-cp35m-manylinux1_x86_64.whl#egg=torch-0.2.0' ] setup( name='inferno-pytorch', version=__version__, description="Inferno is a little library providing utilities and convenience functions/classes around PyTorch.", long_description=readme + '\n\n' + history, author="Nasim Rahaman", author_email='nasim.rahaman@iwr.uni-heidelberg.de', url='https://github.com/inferno-pytorch/inferno', packages=find_packages(where='.', exclude=["*.tests", "*.tests.*", "tests.*", "tests", "__pycache__", "*.pyc"]), dependency_links=dependency_links, include_package_data=True, install_requires=requirements, license="Apache Software License 2.0", zip_safe=False, keywords='inferno pytorch torch deep learning cnn deep-pyromania', classifiers=[ # How mature is this project? Common values are\ # 2 - Pre-Alpha', # 3 - Alpha, # 4 - Beta, # 5 - Production/Stable 'Development Status :: 2 - Pre-Alpha', # Indicate who your project is intended for 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Natural Language :: English', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6' ], test_suite='test', tests_require=test_requirements, setup_requires=setup_requirements, ) ================================================ FILE: tests/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_extensions/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_extensions/test_containers/test_graph.py 
================================================ import unittest from functools import reduce import torch class TestGraph(unittest.TestCase): def setUp(self): import torch.nn as nn from inferno.utils.python_utils import from_iterable class DummyNamedModule(nn.Module): def __init__(self, name, history, num_inputs=1): super(DummyNamedModule, self).__init__() self.name = name self.history = history self.num_inputs = num_inputs def forward(self, *inputs): assert len(inputs) == self.num_inputs self.history.append(self.name) if self.num_inputs > 1: output = reduce(lambda x, y: x + y, inputs) else: output = from_iterable(inputs) return output self.DummyNamedModule = DummyNamedModule # @unittest.skip def test_graph_dummy_basic(self): import torch from inferno.extensions.containers.graph import Graph if not hasattr(self, 'DummyNamedModule'): self.setUp() DummyNamedModule = self.DummyNamedModule history = [] # Build graph model = Graph() model.add_input_node('input_0') model.add_input_node('input_1') model.add_node('conv0_0', DummyNamedModule('conv0_0', history)) model.add_node('conv0_1', DummyNamedModule('conv0_1', history)) model.add_node('conv1', DummyNamedModule('conv1', history, 2)) model.add_node('conv2', DummyNamedModule('conv2', history)) model.add_output_node('output_0') model.add_edge('input_0', 'conv0_0')\ .add_edge('input_1', 'conv0_1')\ .add_edge('conv0_0', 'conv1')\ .add_edge('conv0_1', 'conv1')\ .add_edge('conv1', 'conv2')\ .add_edge('conv2', 'output_0') input_0 = torch.rand(10, 10) input_1 = torch.rand(10, 10) model(input_0, input_1) self.assertTrue(history == ['conv0_0', 'conv0_1', 'conv1', 'conv2'] or history == ['conv0_1', 'conv0_0', 'conv1', 'conv2']) # @unittest.skip def test_graph_dummy_inception(self): import torch from inferno.extensions.containers.graph import Graph if not hasattr(self, 'DummyNamedModule'): self.setUp() DummyNamedModule = self.DummyNamedModule history = [] # Build graph model = Graph() model.add_input_node('input_0') 
model.add_node('conv0', DummyNamedModule('conv0', history), 'input_0') model.add_node('conv1_0', DummyNamedModule('conv1_0', history), 'conv0') model.add_node('conv1_1', DummyNamedModule('conv1_1', history), 'conv0') model.add_node('conv2', DummyNamedModule('conv2', history, 2), ['conv1_0', 'conv1_1']) model.add_output_node('output_0', 'conv2') input_0 = torch.rand(10, 10) model(input_0) self.assertTrue(history == ['conv0', 'conv1_0', 'conv1_1', 'conv2'] or history == ['conv0', 'conv1_1', 'conv1_2', 'conv2']) # @unittest.skip def test_graph_basic(self): from inferno.extensions.containers.graph import Graph from inferno.extensions.layers.convolutional import ConvELU2D from inferno.utils.model_utils import ModelTester # Build graph model = Graph() model.add_input_node('input_0') model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0') model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0') model.add_output_node('output_0', previous='conv1') ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model) @unittest.skipUnless(torch.cuda.is_available(), "No cuda.") def test_graph_device_transfers(self): from inferno.extensions.containers.graph import Graph from inferno.extensions.layers.convolutional import ConvELU2D import torch # Build graph model = Graph() model.add_input_node('input_0') model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0') model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0') model.add_output_node('output_0', previous='conv1') # Transfer model.to_device('conv0', 'cpu').to_device('conv1', 'cuda', 0) x = torch.rand(1, 1, 100, 100) y = model(x) self.assertIsInstance(y.data, torch.cuda.FloatTensor) @unittest.skip("Needs machine with 4 GPUs") def test_multi_gpu(self): import torch import torch.nn as nn from torch.nn.parallel.data_parallel import data_parallel from inferno.extensions.containers.graph import Graph input_shape = [8, 1, 3, 128, 128] model = Graph() \ .add_input_node('input') \ .add_node('conv0', nn.Conv3d(1, 10, 3, 
padding=1), previous='input') \ .add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0') \ .add_output_node('output', previous='conv1') model.cuda() input = torch.rand(*input_shape).cuda() data_parallel(model, input, device_ids=[0, 1, 2, 3]) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_criteria/test_core.py ================================================ import unittest import torch import torch.nn as nn class TestCore(unittest.TestCase): def test_as_2d_criterion(self): from inferno.extensions.criteria.core import As2DCriterion prediction = torch.FloatTensor(2, 10, 100, 100).uniform_() prediction = nn.Softmax2d()(prediction) target = torch.LongTensor(2, 100, 100).fill_(0) criterion = As2DCriterion(nn.CrossEntropyLoss()) criterion(prediction, target) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_criteria/test_elementwise_measures.py ================================================ import unittest import inferno.extensions.criteria.elementwise_measures as em import torch class TestElementwiseMeasures(unittest.TestCase): def test_weighted_mse_loss(self): input = torch.zeros(10, 10) target = torch.ones(10, 10) loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target) self.assertAlmostEqual(loss.item(), 2., delta=1e-5) target = torch.zeros(10, 10) input = torch.ones(10, 10) loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target) self.assertAlmostEqual(loss.item(), 1., delta=1e-5) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_criteria/test_set_similarity_measures.py ================================================ import unittest import torch class SetSimilarityTest(unittest.TestCase): def get_dummy_variables(self): x = torch.zeros(3, 2, 100, 100).uniform_() y = torch.zeros(3, 2, 100, 100).uniform_() 
return x, y def get_dummy_variables_with_channels_and_classes(self): # (batch_size, channels, classes, ...) x = torch.zeros(3, 2, 5, 100, 100).uniform_() y = torch.zeros(3, 2, 5, 100, 100).uniform_() return x, y class TestSorensenDice(SetSimilarityTest): # noinspection PyCallingNonCallable def test_channelwise(self): from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss x, y = self.get_dummy_variables() channelwise = SorensenDiceLoss(channelwise=True) not_channelwise = SorensenDiceLoss(channelwise=False) # Compute expected channelwise loss expected_channelwise_loss = \ not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \ not_channelwise(x[:, 1, ...], y[:, 1, ...]) # Compute channelwise channelwise_loss = channelwise(x, y) # Compare self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item()) class TestGeneralizedSorensenDice(SetSimilarityTest): def test_channelwise(self): from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss x, y = self.get_dummy_variables_with_channels_and_classes() channelwise = GeneralizedDiceLoss(channelwise=True) not_channelwise = GeneralizedDiceLoss(channelwise=False) # Compute channelwise loss and expected one: channelwise_loss = channelwise(x, y) expected_channelwise_loss = \ not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \ not_channelwise(x[:, 1, ...], y[:, 1, ...]) # Compare self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item()) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_layers/deprecated/building_blocks.py ================================================ import unittest import torch import inferno.extensions.layers.building_blocks as bb class ResBlockTest(unittest.TestCase): def test_2D_simple_(self): x = torch.rand(1, 3, 64, 15) model = bb.ResBlock(in_channels=3, out_channels=3, dim=2) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,3, 
64, 15]) def test_3D_simple_(self): x = torch.rand(1,3,20, 64,15) model = bb.ResBlock(in_channels=3, out_channels=3, dim=3) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,3, 20, 64, 15]) def test_2D_simple_2(self): x = torch.rand(1,3,64,64) model = bb.ResBlock(in_channels=3, out_channels=6, dim=2) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64]) def test_2D_simple_3(self): x = torch.rand(1,3,64,64) model = bb.ResBlock(in_channels=3, out_channels=6, dim=2, size=4) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64]) def test_2D_simple_4(self): x = torch.rand(1,6,64,64) model = bb.ResBlock(in_channels=6, out_channels=6, dim=2, size=4, force_skip_op=True) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64]) def test_2D_simple_5(self): x = torch.rand(1,6,64,64) model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4, force_skip_op=True) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64]) def test_2D_simple_6(self): x = torch.rand(1,6,64,64) model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4, force_skip_op=True, activated=False) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64]) def test_3D_simple_6(self): x = torch.rand(1,6,64,64, 20) model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=3, size=4, force_skip_op=True, activated=False) xx = model(x) out_size = xx.size() self.assertEqual(list(out_size), [1,6, 64, 64, 20]) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_layers/test_activations.py ================================================ import unittest import torch import inferno.extensions.layers.activations as activations class ActivationTest(unittest.TestCase): def test_selu(self): x = torch.rand(100) y = activations.SELU()(x) 
self.assertEqual(list(x.size()), list(y.size())) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_layers/test_convolutional.py ================================================ import unittest import torch from inferno.utils.model_utils import ModelTester class TestConvolutional(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") def test_bn_relu_depthwise_conv2d_pyinn(self): from inferno.extensions.layers.convolutional import BNReLUDepthwiseConv2D model = BNReLUDepthwiseConv2D(10, 'auto', 3) ModelTester((1, 10, 100, 100), (1, 10, 100, 100)).cuda()(model) self.assertTrue(model.depthwise) self.assertEqual(model.conv.groups, 10) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_layers/test_device.py ================================================ import unittest from inferno.extensions.layers.device import DeviceTransfer, OnDevice import torch class TransferTest(unittest.TestCase): @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") def test_device_transfer(self): if not torch.cuda.is_available(): return # Build transfer model transfer = DeviceTransfer('cpu') x = torch.rand(10, 10).cuda() y = transfer(x) loss = y.mean() loss.backward() self.assertFalse(y.data.is_cuda) self.assertIsNotNone(x.grad) self.assertTrue(x.grad.data.is_cuda) @unittest.skipIf(not torch.cuda.is_available(), "GPU not available.") def test_on_device(self): if not torch.cuda.is_available(): return # Build variable on the GPU x = torch.rand(1, 10) # Build model over multiple devices multi_device_model = torch.nn.Sequential(OnDevice(torch.nn.Linear(10, 10), 'cuda'), OnDevice(torch.nn.Linear(10, 10), 'cpu')) y = multi_device_model(x) self.assertIsInstance(y.data, torch.FloatTensor) if __name__ == '__main__': unittest.main() ================================================ FILE: 
tests/test_extensions/test_layers/test_reshape.py ================================================ import unittest import torch class TestReshape(unittest.TestCase): def _get_input_variable(self, *shape): return torch.rand(*shape) def test_as_matrix(self): from inferno.extensions.layers.reshape import AsMatrix input = self._get_input_variable(10, 20, 1, 1) as_matrix = AsMatrix() output = as_matrix(input) self.assertEqual(list(output.size()), [10, 20]) def test_flatten(self): from inferno.extensions.layers.reshape import Flatten input = self._get_input_variable(10, 20, 2, 2) flatten = Flatten() output = flatten(input) self.assertEqual(list(output.size()), [10, 80]) def test_as_2d(self): from inferno.extensions.layers.reshape import As2D as_2d = As2D() output_shape = as_2d(self._get_input_variable(10, 20, 3, 30, 30)).size() self.assertEqual(list(output_shape), [10, 60, 30, 30]) output_shape = as_2d(self._get_input_variable(10, 20, 30, 30)).size() self.assertEqual(list(output_shape), [10, 20, 30, 30]) output_shape = as_2d(self._get_input_variable(10, 20)).size() self.assertEqual(list(output_shape), [10, 20, 1, 1]) def test_as_3d(self): from inferno.extensions.layers.reshape import As3D from inferno.utils.exceptions import ShapeError as_3d = As3D() output_shape = as_3d(self._get_input_variable(10, 20, 3, 30, 30)).size() self.assertEqual(list(output_shape), [10, 20, 3, 30, 30]) output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size() self.assertEqual(list(output_shape), [10, 20, 1, 30, 30]) output_shape = as_3d(self._get_input_variable(10, 20)).size() self.assertEqual(list(output_shape), [10, 20, 1, 1, 1]) as_3d.channel_as_z = True output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size() self.assertEqual(list(output_shape), [10, 1, 20, 30, 30]) as_3d.num_channels_or_num_z_slices = 2 output_shape = as_3d(self._get_input_variable(10, 40, 30, 30)).size() self.assertEqual(list(output_shape), [10, 2, 20, 30, 30]) with self.assertRaises(ShapeError): 
output_shape = as_3d(self._get_input_variable(10, 41, 30, 30)).size() self.assertEqual(list(output_shape), [10, 2, 20, 30, 30]) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_metrics/categorical.py ================================================ import unittest import torch from inferno.extensions.metrics import IOU class TestCategorical(unittest.TestCase): def test_iou_basic(self): # from one hot predicted_image = torch.zeros(*(2, 10, 10)) predicted_image[:, 0:4, 0:4] = 1 target_image = torch.zeros(*(2, 10, 10)) target_image[:, 0:3, 0:3] = 1 expected_iou = (3 * 3)/(4 * 4) iou = IOU()(predicted_image[None, ...], target_image[None, ...]) self.assertAlmostEqual(iou, expected_iou, places=4) def test_iou_with_ignore_class(self): predicted_image = torch.zeros(*(2, 10, 10)) predicted_image[0, 0:4, 0:4] = 1 target_image = torch.zeros(*(2, 10, 10)) target_image[:, 0:3, 0:3] = 1 expected_iou = (3 * 3) / (4 * 4) iou = IOU(ignore_class=1)(predicted_image[None, ...], target_image[None, ...]) self.assertAlmostEqual(iou, expected_iou, places=4) def test_multiclass_iou(self): predicted_image = torch.zeros(*(2, 10, 10)) predicted_image[0, 0:4, 0:4] = 1 target_image = torch.zeros(*(2, 10, 10)) target_image[:, 0:3, 0:3] = 1 iou_class_0 = (3 * 3) / (4 * 4) iou_class_1 = 0 expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1) iou = IOU()(predicted_image[None, ...], target_image[None, ...]) self.assertAlmostEqual(iou, expected_mean_iou, places=4) def test_multiclass_iou_with_ignore_class(self): predicted_image = torch.zeros(*(3, 10, 10)) predicted_image[0, 0:4, 0:4] = 1 # Have the third plane be crap predicted_image[2, :, :] = 1 target_image = torch.zeros(*(3, 10, 10)) target_image[:, 0:3, 0:3] = 1 iou_class_0 = (3 * 3) / (4 * 4) iou_class_1 = 0 expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1) iou = IOU(ignore_class=-1)(predicted_image[None, ...], target_image[None, ...]) self.assertAlmostEqual(iou, 
expected_mean_iou, places=4) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_models/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_extensions/test_models/test_res_unet.py ================================================ import unittest import torch import torch.cuda as cuda from inferno.utils.model_utils import ModelTester class ResUNetTest(unittest.TestCase): def test_res_unet_2d(self): from inferno.extensions.models import ResBlockUNet tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256)) if cuda.is_available(): tester.cuda() tester(ResBlockUNet(in_channels=1, out_channels=1, dim=2)) def test_res_unet_3d(self): from inferno.extensions.models import ResBlockUNet tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64)) if cuda.is_available(): tester.cuda() # test default unet 3d tester(ResBlockUNet(in_channels=1, out_channels=1, dim=3)) def test_2d_side_out_bot_up(self): from inferno.extensions.models import ResBlockUNet depth = 3 in_channels = 3 x = torch.rand(1, in_channels, 64, 32) model = ResBlockUNet(in_channels=in_channels, out_channels=8, dim=2, side_out_parts=['bottom','up'], unet_kwargs=dict(depth=depth)) out_list = model(x) self.assertEqual(len(out_list), depth + 1) self.assertEqual(list(out_list[0].size()), [1, 24, 8, 4]) self.assertEqual(list(out_list[1].size()), [1, 12, 16, 8]) self.assertEqual(list(out_list[2].size()), [1, 6, 32, 16]) self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32]) def test_2d_side_out_up(self): from inferno.extensions.models import ResBlockUNet depth = 3 in_channels = 3 x = torch.rand(1, in_channels, 64, 32) model = ResBlockUNet(in_channels=in_channels, out_channels=8, dim=2, side_out_parts=['up'], unet_kwargs=dict(depth=depth)) out_list = model(x) self.assertEqual(len(out_list), depth) 
self.assertEqual(list(out_list[0].size()), [1,12, 16, 8]) self.assertEqual(list(out_list[1].size()), [1, 6, 32, 16]) self.assertEqual(list(out_list[2].size()), [1, 8, 64, 32]) def test_2d_side_out_down(self): from inferno.extensions.models import ResBlockUNet depth = 3 in_channels = 3 x = torch.rand(1, in_channels, 64, 32) model = ResBlockUNet(in_channels=in_channels, out_channels=8, dim=2, side_out_parts=['down'], unet_kwargs=dict(depth=depth)) out_list = model(x) self.assertEqual(len(out_list), depth + 1) self.assertEqual(list(out_list[0].size()), [1, 6, 64, 32]) self.assertEqual(list(out_list[1].size()), [1, 12, 32, 16]) self.assertEqual(list(out_list[2].size()), [1, 24, 16, 8]) # the actual output self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32]) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_extensions/test_models/test_unet.py ================================================ import unittest import torch.cuda as cuda from inferno.utils.model_utils import ModelTester, MultiscaleModelTester from inferno.extensions.models import UNet class _MultiscaleUNet(UNet): def conv_op_factory(self, in_channels, out_channels, part, index): return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True def forward(self, input): x = self._initial_conv(input) x = list(super(UNet, self).forward(x)) x[-1] = self._output(x[-1]) return tuple(x) class UNetTest(unittest.TestCase): def test_unet_2d(self): tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256)) if cuda.is_available(): tester.cuda() tester(UNet(1, 1, dim=2, initial_features=32)) def test_unet_3d(self): tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64)) if cuda.is_available(): tester.cuda() # test default unet 3d tester(UNet(1, 1, dim=3, initial_features=8)) def test_monochannel_unet_3d(self): nc = 2 class _UNetMonochannel(_MultiscaleUNet): def _get_num_channels(self, depth): return nc shapes = [(1, nc, 
16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4), (1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)] tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes) if cuda.is_available(): tester.cuda() tester(_UNetMonochannel(1, 1, dim=3, initial_features=8)) def test_inverse_pyramid_unet_2d(self): class _UNetInversePyramid(_MultiscaleUNet): def _get_num_channels(self, depth): return [13, 12, 11][depth - 1] shapes = [(1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8), (1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64)] tester = MultiscaleModelTester((1, 1, 16, 64), shapes) if cuda.is_available(): tester.cuda() tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_inferno.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- """Tests for `inferno` package.""" import unittest import numpy as np import torch import os import shutil from os.path import dirname, join from torch.utils.data.dataset import TensorDataset from torch.utils.data.dataloader import DataLoader from inferno.extensions.layers import Conv2D, BNReLUConv2D from inferno.extensions.layers import AsMatrix from inferno.extensions.containers import Graph from inferno.trainers.basic import Trainer from inferno.trainers.callbacks.essentials import NaNDetector from inferno.trainers.callbacks.base import Callback from torch import nn class TestInferno(unittest.TestCase): """Tests for `inferno` package.""" NUM_SAMPLES = 100 NUM_TRAINING_SAMPLES = 70 NUM_CLASSES = 10 WORKING_DIRECTORY = dirname(__file__) def read_environment_variables(self): self.NUM_SAMPLES = int(os.getenv('INFERNO_TEST_NUM_SAMPLES', str(self.NUM_SAMPLES))) self.NUM_TRAINING_SAMPLES = int(os.getenv('INFERNO_TEST_NUM_SAMPLES', str(self.NUM_TRAINING_SAMPLES))) self.NUM_CLASSES = 
int(os.getenv('INFERNO_TEST_NUM_CLASSES', str(self.NUM_CLASSES))) self.WORKING_DIRECTORY = os.getenv('INFERNO_TEST_WORKING_DIRECTORY', self.WORKING_DIRECTORY) def setUp(self): """Set up test fixtures, if any.""" self.setUpDatasets() def setUpDatasets(self): # Build training dataset inputs, targets = self.generate_random_data(self.NUM_SAMPLES, (3, 32, 32), num_classes=self.NUM_CLASSES, dtype='float32') # Split to train and split train_inputs, train_targets = inputs[:self.NUM_TRAINING_SAMPLES], \ targets[:self.NUM_TRAINING_SAMPLES] validate_inputs, validate_targets = inputs[self.NUM_TRAINING_SAMPLES:], \ targets[self.NUM_TRAINING_SAMPLES:] # Convert to tensor and build dataset train_dataset = TensorDataset(torch.from_numpy(train_inputs), torch.from_numpy(train_targets)) validate_dataset = TensorDataset(torch.from_numpy(validate_inputs), torch.from_numpy(validate_targets)) # Build dataloaders from dataset self.train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=False) self.validate_loader = DataLoader(validate_dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=False) def setUpCallbacks(self): class RecordSaveInfo(Callback): def __init__(self): super(RecordSaveInfo, self).__init__() self.best_saves_at_iteration_epoch = [] self.saves_at_iteration_epoch = [] def begin_of_save(self, epoch_count, iteration_count, is_iteration_with_best_validation_score, **_): if is_iteration_with_best_validation_score: self.best_saves_at_iteration_epoch.append((iteration_count, epoch_count)) else: self.saves_at_iteration_epoch.append((iteration_count, epoch_count)) self.RecordSaveInfo = RecordSaveInfo def generate_random_data(self, num_samples, shape, num_classes, hardness=0.3, dtype=None): dataset_input = np.zeros((num_samples,) + shape, dtype=dtype) dataset_target = np.random.randint(num_classes, size=num_samples) for sample_num in range(num_samples): dataset_input[sample_num] = np.random.normal(loc=dataset_target[sample_num], 
scale=(1 - hardness), size=shape) return dataset_input, dataset_target def tearDown(self): """Tear down test fixtures, if any.""" if os.path.exists(join(self.WORKING_DIRECTORY, 'Weights')): shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights')) def build_graph_model(self): model = Graph() model\ .add_input_node('input')\ .add_node('conv1', Conv2D(3, 8, 3), 'input')\ .add_node('conv2', BNReLUConv2D(8, 8, 3), 'conv1')\ .add_node('pool1', nn.MaxPool2d(kernel_size=2, stride=2), 'conv2')\ .add_node('conv3', BNReLUConv2D(8, 8, 3), 'pool1')\ .add_node('pool2', nn.MaxPool2d(kernel_size=2, stride=2), 'conv3')\ .add_node('conv4', BNReLUConv2D(8, 8, 3), 'pool2')\ .add_node('pool3', nn.AdaptiveAvgPool2d(output_size=(1, 1)), 'conv4')\ .add_node('matrix', AsMatrix(), 'pool3')\ .add_node('linear', nn.Linear(8, self.NUM_CLASSES), 'matrix')\ .add_output_node('output', 'linear') return model def test_training_cpu(self): """Test Trainer.""" # Build model model = self.build_graph_model() # Build callbacks # save_info_recorder = RecordSaveInfo() # Build trainer trainer = Trainer(model)\ .save_every((2, 'epochs'), to_directory=join(self.WORKING_DIRECTORY, 'Weights'))\ .validate_every((100, 'iterations'), for_num_iterations=10)\ .set_max_num_epochs(4)\ .save_at_best_validation_score()\ .build_optimizer('RMSprop')\ .build_criterion('CrossEntropyLoss')\ .build_metric('CategoricalError')\ .register_callback(NaNDetector) # Bind datasets trainer\ .bind_loader('train', self.train_loader)\ .bind_loader('validate', self.validate_loader) # Go trainer.pickle_module = 'dill' trainer.fit() if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_io/test_box/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package 
for inferno.""" ================================================ FILE: tests/test_io/test_box/test_camvid.py ================================================ import os from os.path import join, dirname, exists, isdir import unittest import numpy as np _CAMVID_ROOT = None def _camvid_available(): return _CAMVID_ROOT is not None or os.environ.get('CAMVID_ROOT') is not None class TestCamvid(unittest.TestCase): CAMVID_ROOT = _CAMVID_ROOT PLOT_DIRECTORY = join(dirname(__file__), 'plots') def get_camvid_root(self): if self.CAMVID_ROOT is None: root = os.environ.get('CAMVID_ROOT') assert root is not None, "Camvid Root not found." else: return self.CAMVID_ROOT @unittest.skipUnless(_camvid_available(), "No root available.") def test_camvid_dataset_without_transforms(self): from inferno.io.box.camvid import CamVid camvid = CamVid(self.get_camvid_root()) image, label = camvid[0] image = np.asarray(image) label = np.asarray(label) self.assertSequenceEqual(image.shape, (360, 480, 3)) self.assertSequenceEqual(label.shape, (360, 480)) self.assertLessEqual(label.max(), 11) @unittest.skipUnless(_camvid_available(), "No root available.") def _test_camvid_dataset_with_transforms(self): from inferno.io.box.camvid import CamVid from inferno.io.transform.base import Compose from inferno.io.transform.image import PILImage2NumPyArray, RandomSizedCrop, Scale from inferno.utils.io_utils import print_tensor camvid = CamVid(self.get_camvid_root(), image_transform=Compose(), label_transform=Compose(), joint_transform=Compose()) camvid.image_transform.add(PILImage2NumPyArray()) camvid.label_transform.add(PILImage2NumPyArray()) image, label = camvid[0] self.assertSequenceEqual(image.shape, (3, 360, 480)) self.assertSequenceEqual(label.shape, (360, 480)) # Add crop trafo camvid.joint_transform.add(RandomSizedCrop(ratio_between=(0.7, 1.0), preserve_aspect_ratio=True)) # We need 2 scale transforms, one with order 3 (image) and the other with order 0 (label) 
camvid.joint_transform.add(Scale(output_image_shape=(360, 480), interpolation_order=3, apply_to=[0])) camvid.joint_transform.add(Scale(output_image_shape=(360, 480), interpolation_order=0, apply_to=[1])) image, label = camvid[0] self.assertSequenceEqual(image.shape, (3, 360, 480)) self.assertSequenceEqual(label.shape, (360, 480)) self.assertLessEqual(len(np.unique(label)), 12) # Print tensors to make sure they look legit if not exists(self.PLOT_DIRECTORY): os.mkdir(self.PLOT_DIRECTORY) else: assert isdir(self.PLOT_DIRECTORY) print_tensor(image[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) print_tensor(label[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) @unittest.skipUnless(_camvid_available(), "No root available.") def test_camvid_dataset_with_transforms(self): from inferno.io.box.camvid import get_camvid_loaders from inferno.utils.io_utils import print_tensor train_loader, validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root()) train_dataset = train_loader.dataset image, label = train_dataset[0] # Make sure the shapes checkout self.assertSequenceEqual(image.size(), (3, 360, 480)) self.assertSequenceEqual(label.size(), (360, 480)) self.assertEqual(image.type(), 'torch.FloatTensor') self.assertEqual(label.type(), 'torch.LongTensor') # Print tensors to make sure they look legit if not exists(self.PLOT_DIRECTORY): os.mkdir(self.PLOT_DIRECTORY) else: assert isdir(self.PLOT_DIRECTORY) print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) @unittest.skipUnless(_camvid_available(), "No root available.") def test_camvid_dataset_with_transforms_onehot(self): from inferno.io.box.camvid import get_camvid_loaders from inferno.utils.io_utils import print_tensor train_loader, 
validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root(), labels_as_onehot=True) train_dataset = train_loader.dataset image, label = train_dataset[0] # Make sure the shapes checkout self.assertSequenceEqual(image.size(), (3, 360, 480)) self.assertSequenceEqual(label.size(), (12, 360, 480)) self.assertEqual(image.type(), 'torch.FloatTensor') self.assertEqual(label.type(), 'torch.FloatTensor') # Print tensors to make sure they look legit if not exists(self.PLOT_DIRECTORY): os.mkdir(self.PLOT_DIRECTORY) else: assert isdir(self.PLOT_DIRECTORY) print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) print_tensor(label.numpy()[None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/test_box/test_cityscapes.py ================================================ import os from os.path import join, dirname, exists, isdir import unittest import numpy as np import time _CITYSCAPES_ROOT = None def _cityscapes_available(): return _CITYSCAPES_ROOT is not None or os.environ.get('CITYSCAPES_ROOT') is not None class TestCityscapes(unittest.TestCase): CITYSCAPES_ROOT = _CITYSCAPES_ROOT PLOT_DIRECTORY = join(dirname(__file__), 'plots') INCLUDE_COARSE = False def get_cityscapes_root(self): if self.CITYSCAPES_ROOT is None: root = os.environ.get('CITYSCAPES_ROOT') assert root is not None, "Cityscapes Root not found." 
else: return self.CITYSCAPES_ROOT @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") def test_cityscapes_dataset_without_transforms(self): from inferno.io.box.cityscapes import Cityscapes cityscapes = Cityscapes(self.get_cityscapes_root()) image, label = cityscapes[0] image = np.asarray(image) label = np.asarray(label) self.assertSequenceEqual(image.shape, (1024, 2048, 3)) self.assertSequenceEqual(label.shape, (1024, 2048)) self.assertLessEqual(label.max(), 33) @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") def test_cityscapes_dataset_without_transforms_unzipped(self): from inferno.io.box.cityscapes import Cityscapes cityscapes = Cityscapes(join(self.get_cityscapes_root(), 'extracted'), read_from_zip_archive=False) image, label = cityscapes[0] image = np.asarray(image) label = np.asarray(label) self.assertSequenceEqual(image.shape, (1024, 2048, 3)) self.assertSequenceEqual(label.shape, (1024, 2048)) self.assertLessEqual(label.max(), 33) @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") def test_cityscapes_dataset_with_transforms(self): from inferno.io.box.cityscapes import get_cityscapes_loaders from inferno.utils.io_utils import print_tensor train_loader, validate_loader = get_cityscapes_loaders(self.get_cityscapes_root(), include_coarse_dataset=self.INCLUDE_COARSE) train_dataset = train_loader.dataset tic = time.time() image, label = train_dataset[0] toc = time.time() print("[+] Loaded sample in {} seconds.".format(toc - tic)) # Make sure the shapes checkout self.assertSequenceEqual(image.size(), (3, 1024, 2048)) self.assertSequenceEqual(label.size(), (1024, 2048)) self.assertEqual(image.type(), 'torch.FloatTensor') self.assertEqual(label.type(), 'torch.LongTensor') # Print tensors to make sure they look legit if not exists(self.PLOT_DIRECTORY): os.mkdir(self.PLOT_DIRECTORY) else: assert isdir(self.PLOT_DIRECTORY) print_tensor(image.numpy()[None, ...], prefix='IMG--', 
directory=self.PLOT_DIRECTORY) for class_id in np.unique(label.numpy()): print_tensor((label.numpy()[None, None, ...] == class_id).astype('float32'), prefix='LAB-{}--'.format(class_id), directory=self.PLOT_DIRECTORY) print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) @unittest.skipUnless(_cityscapes_available(), "No cityscapes available.") def test_cityscapes_dataset_with_transforms_unzipped(self): from inferno.io.box.cityscapes import get_cityscapes_loaders from inferno.utils.io_utils import print_tensor train_loader, validate_loader = get_cityscapes_loaders(join(self.get_cityscapes_root(), 'extracted'), include_coarse_dataset=self.INCLUDE_COARSE, read_from_zip_archive=False) train_dataset = train_loader.dataset tic = time.time() image, label = train_dataset[0] toc = time.time() print("[+] Loaded sample in {} seconds.".format(toc - tic)) # Make sure the shapes checkout self.assertSequenceEqual(image.size(), (3, 1024, 2048)) self.assertSequenceEqual(label.size(), (1024, 2048)) self.assertEqual(image.type(), 'torch.FloatTensor') self.assertEqual(label.type(), 'torch.LongTensor') # Print tensors to make sure they look legit if not exists(self.PLOT_DIRECTORY): os.mkdir(self.PLOT_DIRECTORY) else: assert isdir(self.PLOT_DIRECTORY) print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY) for class_id in np.unique(label.numpy()): print_tensor((label.numpy()[None, None, ...] 
== class_id).astype('float32'), prefix='LAB-{}--'.format(class_id), directory=self.PLOT_DIRECTORY) print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY) print("[+] Inspect images at {}".format(self.PLOT_DIRECTORY)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/test_core/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_io/test_core/test_concatenate.py ================================================ import unittest class ConcatenateTest(unittest.TestCase): def test_concatenate(self): from inferno.io.core import Concatenate from torch.utils.data.dataset import Dataset with self.assertRaises(AssertionError): cated = Concatenate([1, 2, 3], [4, 5, 6, 7]) class ListDataset(list, Dataset): pass dataset_1 = ListDataset([1, 2, 3, 4]) dataset_2 = ListDataset([5, 6, 7]) dataset_3 = ListDataset([8, 9, 10, 11, 12]) cated = Concatenate(dataset_1, dataset_2, dataset_3) self.assertEqual(len(cated), 12) # Try to fetch self.assertEqual(cated[2], 3) self.assertEqual(cated[4], 5) self.assertEqual(cated[6], 7) self.assertEqual(cated[10], 11) self.assertEqual(cated[11], 12) with self.assertRaises(AssertionError): _ = cated[12] if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/test_core/test_zip.py ================================================ import unittest class ZipTest(unittest.TestCase): def test_zip_minimal(self): """Minimal test with python lists as iterators.""" from inferno.io.core import Zip from torch.utils.data.dataset import Dataset with self.assertRaises(TypeError): zipped = Zip([1, 2, 3], [4, 5, 6, 7]) # This is required because Zip checks if its inputs are actually torch datasets class ListDataset(list, Dataset): pass dataset_1 = ListDataset([1, 2, 3, 4]) dataset_2 = 
ListDataset([5, 6, 7, 8, 9]) zipped = Zip(dataset_1, dataset_2) self.assertEqual(len(zipped), 4) fetched = zipped[1] self.assertEqual(fetched, [2, 6]) with self.assertRaises(IndexError): fetched = zipped[4] def test_zip_sync(self): """Test synchronization mechanics.""" # TODO def test_zip_reject(self): from inferno.io.core import ZipReject from torch.utils.data.dataset import Dataset # This is required because Zip checks if its inputs are actually torch datasets class ListDataset(list, Dataset): pass def rejection_criterion(sample_1, sample_2): return sample_1 < sample_2 dataset_1 = ListDataset([1, 2, 3, 4]) dataset_2 = ListDataset([2, 1, 3, 4]) dataset_3 = ListDataset([0, 1, 2, 3]) zipped = ZipReject(dataset_1, dataset_2, dataset_3, rejection_criterion=rejection_criterion, random_jump_after_reject=False, rejection_dataset_indices=[0, 1]) fetched = zipped[0] self.assertSequenceEqual(fetched, [2, 1, 1]) zipped = ZipReject(dataset_1, dataset_2, dataset_3, rejection_criterion=rejection_criterion, rejection_dataset_indices=[1, 0]) fetched = zipped[0] self.assertSequenceEqual(fetched, [1, 2, 0]) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/test_volumetric/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_io/test_volumetric/test_lazy_volume_loader.py ================================================ import unittest import os import numpy as np # try to load io libraries (h5py and z5py) try: import h5py WITH_H5PY = True except ImportError: WITH_H5PY = False # try: # import z5py # WITH_Z5PY = True # except ImportError: # WITH_Z5PY = False class TestLazyVolumeLoader(unittest.TestCase): def tearDown(self): try: os.remove('tmp.h5') except OSError: pass @unittest.skipUnless(WITH_H5PY, "Need h5py") def test_h5_loader(self): from inferno.io.volumetric.lazy_volume_loader import 
LazyHDF5VolumeLoader shape = (100, 100) # test default data loader data = np.arange(np.product(shape)).reshape(shape) with h5py.File('tmp.h5') as f: f.create_dataset('data', data=data) loader = LazyHDF5VolumeLoader('tmp.h5', 'data', window_size=[10, 10], stride=[10, 10], return_index_spec=True) self.assertEqual(loader.shape, shape) for batch, index in loader: expected = data[index.base_sequence_at_index] self.assertEqual(batch.shape, expected.shape) self.assertTrue(np.allclose(batch, expected)) @unittest.skipUnless(WITH_H5PY, "Need h5py") def test_h5_loader_data_slice(self): from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader shape = (100, 100, 100) data_slice = np.s_[:, 20:80, 10:30] # test default data loader data = np.arange(np.product(shape)).reshape(shape) with h5py.File('tmp.h5') as f: f.create_dataset('data', data=data) data = data[data_slice] loader = LazyHDF5VolumeLoader('tmp.h5', 'data', window_size=[10, 10, 10], stride=[10, 10, 10], return_index_spec=True, data_slice=data_slice) self.assertEqual(loader.shape, data.shape) for batch, index in loader: slice_ = index.base_sequence_at_index expected = data[slice_] self.assertEqual(batch.shape, expected.shape) self.assertTrue(np.allclose(batch, expected)) @unittest.skipUnless(WITH_H5PY, "Need h5py") def test_h5_loader_pad(self): from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader shape = (100, 100, 100) pad = [[0, 10], [0, 0], [5, 15]] # test default data loader data = np.arange(np.product(shape)).reshape(shape) with h5py.File('tmp.h5') as f: f.create_dataset('data', data=data) data = np.pad(data, pad_width=pad, mode='constant') loader = LazyHDF5VolumeLoader('tmp.h5', 'data', window_size=[20, 20, 20], stride=[20, 20, 20], return_index_spec=True, padding=pad, padding_mode='constant') self.assertEqual(loader.shape, data.shape) for batch, index in loader: slice_ = index.base_sequence_at_index expected = data[slice_] self.assertEqual(batch.shape, expected.shape) 
self.assertTrue(np.allclose(batch, expected)) @unittest.skipUnless(WITH_H5PY, "Need h5py") def test_h5_loader_data_slice_pad(self): from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader shape = (100, 100, 100) data_slice = np.s_[:, 20:80, 10:90] pad = [[0, 10], [5, 5], [5, 15]] # test default data loader data = np.arange(np.product(shape)).reshape(shape) with h5py.File('tmp.h5') as f: f.create_dataset('data', data=data) data = data[data_slice] data = np.pad(data, pad_width=pad, mode='constant') loader = LazyHDF5VolumeLoader('tmp.h5', 'data', window_size=[20, 20, 20], stride=[20, 20, 20], return_index_spec=True, padding=pad, padding_mode='constant', data_slice=data_slice) self.assertEqual(loader.shape, data.shape) for batch, index in loader: slice_ = index.base_sequence_at_index expected = data[slice_] self.assertEqual(batch.shape, expected.shape) self.assertTrue(np.allclose(batch, expected)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_io/test_volumetric/test_volume_loader.py ================================================ import unittest import os from shutil import rmtree import numpy as np import h5py class TestVolumeLoader(unittest.TestCase): shape = (100, 100, 100) def setUp(self): self.data = np.random.rand(*self.shape) def test_loader(self): from inferno.io.volumetric import VolumeLoader loader = VolumeLoader(self.data, window_size=(10, 10, 10), stride=(10, 10, 10), return_index_spec=True) for batch, idx in loader: slice_ = loader.base_sequence[int(idx)] expected = self.data[slice_] self.assertEqual(batch.shape, expected.shape) self.assertTrue(np.allclose(batch, expected)) class TestHDF5VolumeLoader(unittest.TestCase): shape = (100, 100, 100) def setUp(self): try: os.mkdir('./tmp') except OSError: pass self.data = np.random.rand(*self.shape) with h5py.File('./tmp/data.h5') as f: f.create_dataset('data', data=self.data) def tearDown(self): try: rmtree('./tmp') except OSError: 
pass def test_hdf5_loader(self): from inferno.io.volumetric import HDF5VolumeLoader loader = HDF5VolumeLoader('./tmp/data.h5', 'data', window_size=(10, 10, 10), stride=(10, 10, 10), return_index_spec=True) for batch, idx in loader: slice_ = loader.base_sequence[int(idx)] expected = self.data[slice_] self.assertEqual(batch.shape, expected.shape) self.assertTrue(np.allclose(batch, expected)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_training/__init__.py ================================================ # -*- coding: utf-8 -*- """Unit test package for inferno.""" ================================================ FILE: tests/test_training/test_basic.py ================================================ from unittest import TestCase, skipUnless import torch from unittest import main import time from os.path import join, dirname class TestTrainer(TestCase): # Parameters ROOT_DIR = dirname(__file__) CUDA = False HALF_PRECISION = False DOWNLOAD_CIFAR = True @staticmethod def _make_test_model(): import torch.nn as nn from inferno.extensions.layers.reshape import AsMatrix toy_net = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1), nn.ELU(), nn.MaxPool2d(2), nn.Conv2d(8, 8, 3, 1, 1), nn.ELU(), nn.MaxPool2d(2), nn.Conv2d(8, 16, 3, 1, 1), nn.ELU(), nn.AdaptiveAvgPool2d((1, 1)), AsMatrix(), nn.Linear(16, 10)) return toy_net def test_cifar(self): from inferno.trainers.basic import Trainer from inferno.io.box.cifar import get_cifar10_loaders # Build cifar10 loaders trainloader, testloader = get_cifar10_loaders(root_directory=join(self.ROOT_DIR, 'data'), download=self.DOWNLOAD_CIFAR) # Make model net = self._make_test_model() tic = time.time() # Make trainer trainer = Trainer(model=net)\ .build_optimizer('Adam')\ .build_criterion('CrossEntropyLoss')\ .build_metric('CategoricalError')\ .validate_every((1, 'epochs'))\ .save_every((1, 'epochs'), to_directory=join(self.ROOT_DIR, 'saves'))\ .save_at_best_validation_score()\ 
.set_max_num_epochs(2) # Bind trainer to datasets trainer.bind_loader('train', trainloader).bind_loader('validate', testloader) # Check device and fit if self.CUDA: if self.HALF_PRECISION: trainer.cuda().set_precision('half').fit() else: trainer.cuda().fit() else: trainer.fit() toc = time.time() print("[*] Elapsed time: {} seconds.".format(toc - tic)) def test_multi_io(self): from torch.utils.data.dataset import Dataset from torch.utils.data.dataloader import DataLoader from inferno.trainers.basic import Trainer class DummyDataset(Dataset): def __len__(self): return 42 def __getitem__(self, item): # 2 inputs and 3 targets (say) return torch.rand(3, 32, 32), \ torch.rand(3, 32, 32), \ torch.rand(1).uniform_(), \ torch.rand(1).uniform_(), \ torch.rand(1).uniform_() class DummyNetwork(torch.nn.Module): def __init__(self): super(DummyNetwork, self).__init__() self.conv = torch.nn.Conv2d(3, 1, 3, padding=1) def forward(self, *inputs): assert len(inputs) == 2 out = self.conv(inputs[0]) return out.view(inputs[0].size(0), -1).mean(1), \ out.view(inputs[0].size(0), -1).mean(1), \ out.view(inputs[0].size(0), -1).mean(1) class DummyCriterion(torch.nn.Module): def forward(self, predictions, targets): assert len(predictions) == len(targets) == 3 return predictions[0].mean() loader = DataLoader(DummyDataset()) net = DummyNetwork() trainer = Trainer(net)\ .build_criterion(DummyCriterion)\ .build_optimizer('Adam')\ .set_max_num_iterations(50)\ .bind_loader('train', loader, num_inputs=2, num_targets=3) trainer.fit() def test_serialization(self): from inferno.trainers.basic import Trainer import os # Make model net = self._make_test_model() # Make trainer trainer = Trainer(model=net) \ .build_optimizer('Adam') \ .build_criterion('CrossEntropyLoss') \ .build_metric('CategoricalError') \ .validate_every((1, 'epochs')) \ .save_every((1, 'epochs'), to_directory=os.path.join(self.ROOT_DIR, 'saves')) \ .save_at_best_validation_score() \ .set_max_num_epochs(2) # Try to serialize 
trainer.save() # Try to unserialize trainer = Trainer(net).save_to_directory(os.path.join(self.ROOT_DIR, 'saves')).load() @skipUnless(torch.cuda.device_count() >= 4, "Not enough cuda devices for test_multi_gpu.") def test_multi_gpu(self): if not torch.cuda.is_available(): return from inferno.trainers.basic import Trainer from inferno.io.box.cifar import get_cifar10_loaders import os # Make model net = self._make_test_model() # Make trainer trainer = Trainer(model=net) \ .build_optimizer('Adam') \ .build_criterion('CrossEntropyLoss') \ .build_metric('CategoricalError') \ .validate_every((1, 'epochs')) \ .save_every((1, 'epochs'), to_directory=os.path.join(self.ROOT_DIR, 'saves')) \ .save_at_best_validation_score() \ .set_max_num_epochs(2)\ .cuda(devices=[0, 1, 2, 3], base_device='cpu') train_loader, validate_loader = get_cifar10_loaders(root_directory=self.ROOT_DIR, download=True) trainer.bind_loader('train', train_loader) trainer.bind_loader('validate', validate_loader) trainer.fit() def test_save(self): from inferno.trainers.basic import Trainer trainer = Trainer().save_to_directory(to_directory=self.ROOT_DIR, checkpoint_filename='dummy.pytorch') trainer.save() # Instantiate new trainer and load trainer = Trainer().load(from_directory=self.ROOT_DIR, filename='dummy.pytorch') @skipUnless(torch.cuda.device_count() >= 2, "Not enough cuda devices for test_multi_gpu_setup.") def test_multi_gpu_setup(self): from torch.nn import CrossEntropyLoss from inferno.trainers.basic import Trainer # Test base_device = 'cpu' # Build model net = self._make_test_model() # Make dummy criterion criterion = CrossEntropyLoss(weight=torch.rand(10)) # Make trainer trainer = Trainer(net).build_criterion(criterion).cuda([0, 1], base_device='cpu') self.assertIsInstance(trainer.criterion.weight, torch.FloatTensor) # Test base_device = 'cpu' # Build model net = self._make_test_model() criterion = CrossEntropyLoss(weight=torch.rand(10)) # Make trainer trainer = 
class DummyCallback(Callback):
    """Callback that only checks it was bound to a trainer when invoked."""
    def end_of_training_iteration(self, **_):
        assert self.trainer is not None


class WrongDummyCallback(Callback):
    """Defines a method name the callback engine does not recognize."""
    def end_of_iteration(self):
        pass


class CallbackMechTest(unittest.TestCase):
    """Tests for the callback registration and serialization mechanics."""
    ROOT_DIR = join(dirname(__file__), 'root')

    def setUp(self):
        makedirs(self.ROOT_DIR, exist_ok=True)

    def tearDown(self):
        if exists(self.ROOT_DIR):
            rmtree(self.ROOT_DIR)

    def test_serialization(self):
        """Engine round-trips through torch.save/load with the trainer detached."""
        engine = CallbackEngine().bind_trainer(Trainer())
        engine.register_callback(DummyCallback())
        engine_path = join(self.ROOT_DIR, 'callback_engine.pkl')
        torch.save(engine, engine_path)
        engine = torch.load(engine_path)
        # The trainer must not survive pickling ...
        self.assertIsNone(engine._trainer)
        # ... but the registered callback must.
        registered = next(iter(engine._callback_registry.get('end_of_training_iteration')))
        self.assertIsInstance(registered, DummyCallback)

    def test_auto_registry(self):
        """register_callback wires events by method name and rejects unknown ones."""
        engine = CallbackEngine().bind_trainer(Trainer())
        engine.register_callback(DummyCallback())
        registered = next(iter(engine._callback_registry.get('end_of_training_iteration')))
        self.assertIsInstance(registered, DummyCallback)
        with self.assertRaises(AssertionError):
            engine.register_callback(WrongDummyCallback())

    def test_instance_registry(self):
        """get_instances tracks every instance per Callback subclass."""
        class Foo(Callback):
            pass

        class Bar(Callback):
            pass

        foo = Foo()
        bar = Bar()
        # A single instance is returned bare, not wrapped in a list.
        self.assertIs(foo.get_instances(), foo)
        self.assertIs(bar.get_instances(), bar)
        foo2 = Foo()
        # A second instance promotes the registry entry to a list.
        self.assertSequenceEqual(foo2.get_instances(), [foo, foo2])
        self.assertIs(bar.get_instances(), bar)


if __name__ == '__main__':
    unittest.main()
class TestEssentials(unittest.TestCase):
    """Tests for the essential callbacks (HDF5 dumping)."""
    WORKING_DIRECTORY = dirname(__file__)

    def setUp(self):
        # Small convolutional classifier -- just enough to run a few epochs fast.
        model = Sequential(Conv2D(3, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           MaxPool2d(2, 2),
                           Conv2D(8, 8, 3, activation='ReLU'),
                           AdaptiveAvgPool2d((1, 1)),
                           AsMatrix(),
                           Linear(8, 10))
        train_loader = generate_random_dataloader(512, (3, 32, 32), 10,
                                                  batch_size=16, dtype='float32')
        validate_loader = generate_random_dataloader(32, (3, 32, 32), 10,
                                                     batch_size=16, dtype='float32')
        self.trainer = Trainer(model) \
            .bind_loader('train', train_loader) \
            .bind_loader('validate', validate_loader) \
            .save_to_directory(to_directory=join(self.WORKING_DIRECTORY, 'Weights')) \
            .build_criterion('CrossEntropyLoss') \
            .build_optimizer('RMSprop')

    def test_dump_hdf5_every(self):
        """DumpHDF5Every writes per-epoch training and per-run validation dumps."""
        dump_directory = join(self.WORKING_DIRECTORY, 'Weights')
        dumper = DumpHDF5Every((1, 'epoch'),
                               to_directory=dump_directory,
                               dump_after_every_validation_run=True)
        self.trainer \
            .set_max_num_epochs(4) \
            .register_callback(dumper) \
            .validate_every((16, 'iterations'))
        self.trainer.fit()
        all_files = listdir(dump_directory)
        # 512 samples at batch size 16 -> 32 training iterations per epoch.
        for epoch in range(5):
            self.assertIn('dump.training.epoch{}.iteration{}.h5'.format(epoch, epoch * 32),
                          all_files)
            # We don't validate at last epoch
            if epoch != 4:
                self.assertIn('dump.validation.epoch{}.iteration{}.h5'
                              .format(epoch, (epoch * 32) + 16),
                              all_files)
                self.assertIn('dump.validation.epoch{}.iteration{}.h5'
                              .format(epoch, (epoch * 32) + 32),
                              all_files)
        # Training dumps should carry the training_* datasets ...
        sample_file_path = join(dump_directory, 'dump.training.epoch0.iteration0.h5')
        with h5.File(sample_file_path, 'r') as sample_file:
            self.assertSequenceEqual(list(sample_file.keys()),
                                     ['training_inputs_0',
                                      'training_prediction',
                                      'training_target'])
        # ... and validation dumps the validation_* ones.
        sample_file_path = join(dump_directory, 'dump.validation.epoch0.iteration16.h5')
        with h5.File(sample_file_path, 'r') as sample_file:
            self.assertSequenceEqual(list(sample_file.keys()),
                                     ['validation_inputs_0',
                                      'validation_prediction',
                                      'validation_target'])

    def tearDown(self):
        shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights'))


if __name__ == '__main__':
    unittest.main()
class DummyLogger(Logger):
    """Minimal logger: implements one callback hook and does nothing in it."""
    def end_of_training_iteration(self, **_):
        pass


class TestLogger(unittest.TestCase):
    """Serialization behavior of trainer-attached loggers."""
    ROOT = dirname(__file__)

    def test_serialization(self):
        """After save/load, trainer logger and registered callback are one object."""
        save_directory = join(self.ROOT, 'saves')
        trainer = Trainer() \
            .build_logger(logger=DummyLogger()) \
            .save_to_directory(save_directory)
        trainer.save()
        # Deserialize into a fresh trainer.
        trainer = Trainer().load(from_directory=save_directory)
        # The logger kept by the trainer and the one in the callback registry
        # must be the very same instance, rebound to the new trainer.
        logger_from_trainer = trainer._logger
        registry_entry = trainer.callbacks._callback_registry['end_of_training_iteration']
        logger_from_callback_engine = next(iter(registry_entry))
        self.assertIs(logger_from_trainer, logger_from_callback_engine)
        self.assertIs(logger_from_callback_engine.trainer, trainer)


if __name__ == '__main__':
    unittest.main()
(1) random_array = torch.from_numpy(np.random.rand(*data_shape)).float() target_array = torch.from_numpy(np.random.randint(0, 9, size=target_shape)) train_dataset = TensorDataset(random_array, target_array) test_dataset = TensorDataset(random_array, target_array) # Build dataloaders from dataset train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=False) test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=False) return train_loader, test_loader def get_trainer(self, input_channels): # Build model net = self._make_test_model(input_channels) # Build trainer trainer = Trainer(net)\ .build_logger(TensorboardLogger(send_image_at_batch_indices=0, send_image_at_channel_indices='all', log_images_every=(20, 'iterations')), log_directory=self.LOG_DIRECTORY)\ .build_criterion('CrossEntropyLoss')\ .build_metric('CategoricalError')\ .build_optimizer('Adam')\ .validate_every((1, 'epochs'))\ .save_every((2, 'epochs'), to_directory=self.SAVE_DIRECTORY)\ .save_at_best_validation_score()\ .set_max_num_epochs(2)\ .set_precision(self.PRECISION) # Bind loaders train_loader, test_loader = self.get_random_dataloaders(input_channels=input_channels) trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader) return trainer def test_tensorboard(self): trainer = self.get_trainer(3) trainer.fit() def test_tensorboard_grayscale(self): trainer = self.get_trainer(1) trainer.fit() def test_serialization(self): trainer = self.get_trainer(3) # Serialize trainer.save() # Unserialize trainer = Trainer().load(os.path.join(self.ROOT_DIR, 'saves')) train_loader, test_loader = self.get_random_dataloaders(input_channels=3) trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader) trainer.fit() if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_training/test_callbacks/test_scheduling.py 
class TestSchedulers(unittest.TestCase):
    """Tests for learning-rate scheduling callbacks."""

    def test_manual_lr(self):
        """ManualLR multiplies the lr by each factor exactly when its trigger hits."""

        class DummyTrainer(object):
            # Bare-bones stand-in exposing just the attributes ManualLR reads.
            def __init__(self):
                self.iteration_count = 0
                self.epoch_count = 0
                self.optimizer = Adam(nn.Linear(10, 10).parameters(), lr=1.)

        manual_lr = ManualLR([((100, 'iterations'), 0.5),
                              ((200, 'iterations'), 0.5),
                              ((200, 'iterations'), 0.1)])
        trainer = DummyTrainer()
        manual_lr._trainer = trainer
        # Before any trigger point, the lr is untouched.
        manual_lr.end_of_training_iteration()
        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 1.)
        # At iteration 100 the first trigger fires: 1.0 * 0.5.
        trainer.iteration_count = 100
        manual_lr.end_of_training_iteration()
        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.5)
        # At iteration 200 both 200-triggers fire: 0.5 * 0.5 * 0.1.
        trainer.iteration_count = 200
        manual_lr.end_of_training_iteration()
        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025)
        # FIX: the original asserted at iteration 300 without invoking the
        # callback, making the "no further decay" check vacuous. The values
        # above show triggers fire at exact counts only, so nothing matches 300.
        trainer.iteration_count = 300
        manual_lr.end_of_training_iteration()
        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025)


if __name__ == '__main__':
    unittest.main()
with self.assertRaises(ShapeError): mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32)).cuda()(model) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_utils/test_partial_cls.py ================================================ import unittest import inferno.utils.model_utils as mu from inferno.utils.partial_cls import register_partial_cls import torch import torch.nn as nn class TestCls(object): def __init__(self, a, b, c=1, d=2): self.a = a self.b = b self.c = c self.d = d class PartialClsTester(unittest.TestCase): def test_partial_cls(self): register_partial_cls(TestCls, 'TestA', fix=dict(a='a'), default=dict(b='b'), module=__name__ ) assert 'TestA' in globals() inst = TestA() assert inst.a == 'a' assert inst.b == 'b' assert inst.c == 1 assert inst.d == 2 inst = TestA('fu','bar','fubar') assert inst.a == 'a' assert inst.b == 'fu' assert inst.c == 'bar' assert inst.d == 'fubar' with self.assertRaises(TypeError): inst = TestA(a=2) def test_update_existing_default_cls(self): register_partial_cls(TestCls, 'TestA', fix=dict(a='a'), default=dict(d=3), module=__name__ ) assert 'TestA' in globals() inst = TestA(42) assert inst.a == 'a' assert inst.b == 42 assert inst.c == 1 assert inst.d == 3 with self.assertRaises(TypeError): inst = TestA() def test_fix_nothing(self): register_partial_cls(TestCls, 'TestA', module=__name__ ) assert 'TestA' in globals() inst = TestA(1,2,3,4) assert inst.a == 1 assert inst.b == 2 assert inst.c == 3 assert inst.d == 4 with self.assertRaises(TypeError): inst = TestA() def test_fix_all(self): register_partial_cls(TestCls, 'TestA', module=__name__, fix=dict(a=4, b=3, c=2, d=1) ) assert 'TestA' in globals() inst = TestA() assert inst.a == 4 assert inst.b == 3 assert inst.c == 2 assert inst.d == 1 with self.assertRaises(TypeError): inst = TestA('a') with self.assertRaises(TypeError): inst = TestA(a=1) with self.assertRaises(TypeError): inst = TestA(b=1) with self.assertRaises(TypeError): 
class FrequencyTest(unittest.TestCase):
    """Tests for the Frequency / Duration triggering helpers."""

    def test_from_string(self):
        """Frequencies parse from 'N units' strings, plus 'never' and 'inf' forms."""
        freq = tu.Frequency.from_string('10 epochs')
        self.assertFalse(freq.match(epoch_count=9))
        self.assertTrue(freq.match(epoch_count=10))
        # Singular unit names are normalized to plural.
        freq = tu.Frequency.from_string('1 iteration')
        self.assertEqual(freq.units, 'iterations')
        self.assertTrue(freq.match(iteration_count=10))
        # 'never' and infinite frequencies never match.
        freq = tu.Frequency.from_string('never')
        self.assertFalse(freq.match(epoch_count=9))
        freq = tu.Frequency.from_string('inf epochs')
        self.assertFalse(freq.match(epoch_count=9))

    def test_from_tuple(self):
        """build_from accepts (value, unit) tuples; an infinite value never matches."""
        freq = tu.Frequency.build_from((np.inf, 'epoch'))
        self.assertFalse(freq.match(epoch_count=9))
        self.assertFalse(freq.match(epoch_count=10))

    def test_is_consistent(self):
        # An unknown unit string makes the frequency inconsistent.
        freq = tu.Frequency.build_from('10 epochs')
        freq._units = 'banana'
        self.assertFalse(freq.is_consistent)

    def test_init(self):
        """A default-constructed Frequency is infinite with the priority unit."""
        freq = tu.Frequency()
        self.assertEqual(freq.value, np.inf)
        self.assertEqual(freq.units, freq.UNIT_PRIORITY)

    def test_duration(self):
        """A Duration matches strictly after its point; equality matching is opt-in."""
        duration = tu.Duration.build_from((3, 'iterations'))
        self.assertFalse(duration.match(iteration_count=2))
        self.assertFalse(duration.match(iteration_count=3))
        self.assertTrue(duration.match(iteration_count=3, when_equal_return=True))
        self.assertTrue(duration.match(iteration_count=4))
        # compare reports the per-unit distance to the duration point.
        self.assertEqual(duration.compare(iteration_count=1, epoch_count=3).get('iterations'), 2)
        # Matching in the wrong unit is an error.
        with self.assertRaises(ValueError):
            duration.match(epoch_count=2)


if __name__ == '__main__':
    unittest.main()